diff --git a/fine_tune.py b/fine_tune.py index 2b5255dc..d927bd73 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -198,14 +198,7 @@ def train(args): print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args.lr_scheduler, - optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, - power=args.lr_scheduler_power, - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: diff --git a/library/train_util.py b/library/train_util.py index 9f541b6c..d63a29ae 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,6 +1,7 @@ # common functions for training import argparse +import ast import importlib import json import pathlib @@ -1720,6 +1721,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + parser.add_argument("--lr_scheduler_type", type=str, default="", + help="custom scheduler module") + parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*', + help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")") + parser.add_argument( "--lr_scheduler", type=str, @@ -2284,14 +2290,7 @@ def get_optimizer(args, trainable_params): # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts -def get_scheduler_fix( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - num_cycles: int = 1, - power: float = 1.0, -): +def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int): """ Unified API to get any scheduler from its name. Args: @@ -2312,6 +2311,44 @@ def get_scheduler_fix( last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. """ + name = args.lr_scheduler + num_warmup_steps = args.lr_warmup_steps + num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power + + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split('=') + + value = value.split(",") + for i in range(len(value)): + if value[i].lower() == "true" or value[i].lower() == "false": + value[i] = (value[i].lower() == "true") + else: + value[i] = ast.literal_eval(value[i]) + if len(value) == 1: + value = value[0] + else: + value = list(value) # some may use list? + + lr_scheduler_kwargs[key] = value + + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return lr_scheduler + if name.startswith("adafactor"): assert ( type(optimizer) == transformers.optimization.Adafactor diff --git a/train_db.py b/train_db.py index c812bbc7..88ddcc99 100644 --- a/train_db.py +++ b/train_db.py @@ -166,14 +166,7 @@ def train(args): args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する - lr_scheduler = train_util.get_scheduler_fix( - args.lr_scheduler, - optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps, - num_cycles=args.lr_scheduler_num_cycles, - power=args.lr_scheduler_power, - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: diff --git a/train_network.py b/train_network.py index ca0da112..7f910df4 100644 --- a/train_network.py +++ b/train_network.py @@ -201,14 +201,7 @@ def train(args): print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args.lr_scheduler, - optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, - power=args.lr_scheduler_power, - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f591dea1..e4ab7b5c 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -261,14 +261,7 @@ def train(args): print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args.lr_scheduler, - optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, - power=args.lr_scheduler_power, - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(