fix num_processes, fix indent

Author: Kohya S
Date: 2023-03-19 10:52:46 +09:00
Commit: 64d85b2f51
Parent: 8f08feb577

2 changed files with 36 additions and 49 deletions


@@ -1721,10 +1721,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."',
     )
-    parser.add_argument("--lr_scheduler_type", type=str, default="",
-                        help="custom scheduler module")
-    parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
-                        help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\"")
+    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
+    parser.add_argument(
+        "--lr_scheduler_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100"',
+    )
     parser.add_argument(
         "--lr_scheduler",
@@ -2290,26 +2294,9 @@ def get_optimizer(args, trainable_params):
 # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
-def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
+def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     Unified API to get any scheduler from its name.
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int``, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_cycles (`int`, *optional*):
-            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-        power (`float`, *optional*, defaults to 1.0):
-            Power factor. See `POLYNOMIAL` scheduler
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
     """
     name = args.lr_scheduler
     num_warmup_steps = args.lr_warmup_steps
@@ -2319,35 +2306,35 @@ def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
     lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
-      for arg in args.lr_scheduler_args:
-        key, value = arg.split('=')
+        for arg in args.lr_scheduler_args:
+            key, value = arg.split("=")
-        value = value.split(",")
-        for i in range(len(value)):
-          if value[i].lower() == "true" or value[i].lower() == "false":
-            value[i] = (value[i].lower() == "true")
-          else:
-            value[i] = ast.literal_eval(value[i])
-        if len(value) == 1:
-          value = value[0]
-        else:
-          value = list(value) # some may use list?
+            value = value.split(",")
+            for i in range(len(value)):
+                if value[i].lower() == "true" or value[i].lower() == "false":
+                    value[i] = value[i].lower() == "true"
+                else:
+                    value[i] = ast.literal_eval(value[i])
+            if len(value) == 1:
+                value = value[0]
+            else:
+                value = list(value) # some may use list?
-        lr_scheduler_kwargs[key] = value
+            lr_scheduler_kwargs[key] = value

     # using any lr_scheduler from other library
     if args.lr_scheduler_type:
-      lr_scheduler_type = args.lr_scheduler_type
-      print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
-      if "." not in lr_scheduler_type: # default to use torch.optim
-        lr_scheduler_module = torch.optim.lr_scheduler
-      else:
-        values = lr_scheduler_type.split(".")
-        lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
-        lr_scheduler_type = values[-1]
-      lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
-      lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
-      return lr_scheduler
+        lr_scheduler_type = args.lr_scheduler_type
+        print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+        if "." not in lr_scheduler_type: # default to use torch.optim
+            lr_scheduler_module = torch.optim.lr_scheduler
+        else:
+            values = lr_scheduler_type.split(".")
+            lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+            lr_scheduler_type = values[-1]
+        lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+        return lr_scheduler

     if name.startswith("adafactor"):
         assert (
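
Not part of the commit: a standalone sketch that mirrors the re-indented kwargs parsing and the custom-scheduler lookup above, so the behavior can be checked in isolation. The example argument values and the throwaway SGD optimizer are assumptions for illustration.

import ast
import importlib

import torch

lr_scheduler_type = "CosineAnnealingLR"            # stand-in for args.lr_scheduler_type
lr_scheduler_args = ["T_max=100", "eta_min=1e-6"]  # stand-in for args.lr_scheduler_args

# parse "key=value[,value...]" tokens into kwargs, as in the fixed loop
lr_scheduler_kwargs = {}
for arg in lr_scheduler_args:
    key, value = arg.split("=")
    value = value.split(",")
    for i in range(len(value)):
        if value[i].lower() == "true" or value[i].lower() == "false":
            value[i] = value[i].lower() == "true"
        else:
            value[i] = ast.literal_eval(value[i])
    lr_scheduler_kwargs[key] = value[0] if len(value) == 1 else value

# bare class names resolve against torch.optim.lr_scheduler, dotted names are imported
if "." not in lr_scheduler_type:
    lr_scheduler_module = torch.optim.lr_scheduler
else:
    values = lr_scheduler_type.split(".")
    lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
    lr_scheduler_type = values[-1]

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
lr_scheduler = getattr(lr_scheduler_module, lr_scheduler_type)(optimizer, **lr_scheduler_kwargs)
print(type(lr_scheduler).__name__, lr_scheduler_kwargs)  # CosineAnnealingLR {'T_max': 100, 'eta_min': 1e-06}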


@@ -166,7 +166,7 @@ def train(args):
         args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end

     # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
     if args.full_fp16:
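
Not part of the commit: the num_processes fix simply forwards Accelerate's world size to get_scheduler_fix, which presumably uses it when scaling warmup/total step counts (that part lies outside this hunk). A tiny standalone check of where the value comes from:

from accelerate import Accelerator

accelerator = Accelerator()
# 1 when run directly; N when launched via `accelerate launch --num_processes N`.
# This is the value the fixed call now passes as the third argument to
# train_util.get_scheduler_fix.
print(accelerator.num_processes)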