diff --git a/library/train_util.py b/library/train_util.py
index d63a29ae..af0ed09c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1721,11 +1721,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
     )
 
-    parser.add_argument("--lr_scheduler_type", type=str, default="",
-                        help="custom scheduler module")
-    parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
-                        help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
-
+    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
+    parser.add_argument(
+        "--lr_scheduler_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
+    )
+
     parser.add_argument(
         "--lr_scheduler",
         type=str,
@@ -2083,7 +2087,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
             # if value is dict, save all key and value into one dict
             for key, value in section_dict.items():
                 ignore_nesting_dict[key] = value
-    
+
     config_args = argparse.Namespace(**ignore_nesting_dict)
     args = parser.parse_args(namespace=config_args)
     args.config_file = os.path.splitext(args.config_file)[0]
@@ -2290,26 +2294,9 @@ def get_optimizer(args, trainable_params):
 
 
 # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
-def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
+def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     Unified API to get any scheduler from its name.
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int``, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_cycles (`int`, *optional*):
-            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-        power (`float`, *optional*, defaults to 1.0):
-            Power factor. See `POLYNOMIAL` scheduler
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
     """
     name = args.lr_scheduler
     num_warmup_steps = args.lr_warmup_steps
@@ -2319,35 +2306,35 @@ def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
 
     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
-      for arg in args.lr_scheduler_args:
-        key, value = arg.split('=')
+        for arg in args.lr_scheduler_args:
+            key, value = arg.split("=")
 
-        value = value.split(",")
-        for i in range(len(value)):
-          if value[i].lower() == "true" or value[i].lower() == "false":
-            value[i] = (value[i].lower() == "true")
-          else:
-            value[i] = ast.literal_eval(value[i])
-        if len(value) == 1:
-          value = value[0]
-        else:
-          value = list(value) # some may use list?
+            value = value.split(",")
+            for i in range(len(value)):
+                if value[i].lower() == "true" or value[i].lower() == "false":
+                    value[i] = value[i].lower() == "true"
+                else:
+                    value[i] = ast.literal_eval(value[i])
+            if len(value) == 1:
+                value = value[0]
+            else:
+                value = list(value)  # some may use list?
 
-        lr_scheduler_kwargs[key] = value
+            lr_scheduler_kwargs[key] = value
 
     # using any lr_scheduler from other library
     if args.lr_scheduler_type:
-      lr_scheduler_type = args.lr_scheduler_type
-      print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
-      if "." not in lr_scheduler_type: # default to use torch.optim
-        lr_scheduler_module = torch.optim.lr_scheduler
-      else:
-        values = lr_scheduler_type.split(".")
-        lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
-        lr_scheduler_type = values[-1]
-      lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
-      lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
-      return lr_scheduler
+        lr_scheduler_type = args.lr_scheduler_type
+        print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+        if "." not in lr_scheduler_type:  # default to use torch.optim
+            lr_scheduler_module = torch.optim.lr_scheduler
+        else:
+            values = lr_scheduler_type.split(".")
+            lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+            lr_scheduler_type = values[-1]
+        lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+        return lr_scheduler
 
     if name.startswith("adafactor"):
         assert (
diff --git a/train_db.py b/train_db.py
index 88ddcc99..81aeda19 100644
--- a/train_db.py
+++ b/train_db.py
@@ -166,7 +166,7 @@ def train(args):
         args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
     # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
     if args.full_fp16:
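For reference, the `--lr_scheduler_args` handling reformatted above parses each `KEY=VALUE` token into a typed keyword argument: comma-separated values become lists, `true`/`false` become booleans, and everything else goes through `ast.literal_eval`. A minimal standalone sketch of that logic (the helper name `parse_scheduler_args` is hypothetical, not part of this patch):

```python
import ast


def parse_scheduler_args(scheduler_args):
    """Hypothetical helper mirroring the kwargs parsing in get_scheduler_fix."""
    kwargs = {}
    for arg in scheduler_args:  # each token looks like "T_max=100" or "betas=0.9,0.999"
        key, value = arg.split("=")
        values = value.split(",")
        for i in range(len(values)):
            if values[i].lower() in ("true", "false"):
                values[i] = values[i].lower() == "true"  # case-insensitive booleans
            else:
                values[i] = ast.literal_eval(values[i])  # ints, floats, quoted strings, ...
        kwargs[key] = values[0] if len(values) == 1 else values
    return kwargs


# Example:
# parse_scheduler_args(["T_max=100", "eta_min=1e-6", "verbose=True"])
# -> {"T_max": 100, "eta_min": 1e-06, "verbose": True}
```

Combined with `--lr_scheduler_type`, a dotless name such as `CosineAnnealingLR` resolves against `torch.optim.lr_scheduler`, while a dotted path is imported with `importlib`, and the parsed kwargs are passed to the scheduler class along with the optimizer.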