fix num_processes, fix indent

This commit is contained in:
Kohya S
2023-03-19 10:52:46 +09:00
parent 8f08feb577
commit 64d85b2f51
2 changed files with 36 additions and 49 deletions

View File

@@ -1721,10 +1721,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."', help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."',
) )
parser.add_argument("--lr_scheduler_type", type=str, default="", parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
help="custom scheduler module") parser.add_argument(
parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*', "--lr_scheduler_args",
help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\"") type=str,
default=None,
nargs="*",
help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100"',
)
parser.add_argument( parser.add_argument(
"--lr_scheduler", "--lr_scheduler",
@@ -2290,26 +2294,9 @@ def get_optimizer(args, trainable_params):
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int): def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
""" """
Unified API to get any scheduler from its name. Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
""" """
name = args.lr_scheduler name = args.lr_scheduler
num_warmup_steps = args.lr_warmup_steps num_warmup_steps = args.lr_warmup_steps
@@ -2320,12 +2307,12 @@ def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
for arg in args.lr_scheduler_args: for arg in args.lr_scheduler_args:
key, value = arg.split('=') key, value = arg.split("=")
value = value.split(",") value = value.split(",")
for i in range(len(value)): for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false": if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = (value[i].lower() == "true") value[i] = value[i].lower() == "true"
else: else:
value[i] = ast.literal_eval(value[i]) value[i] = ast.literal_eval(value[i])
if len(value) == 1: if len(value) == 1:

View File

@@ -166,7 +166,7 @@ def train(args):
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16: if args.full_fp16: