From 0fef7b46841da2f6b06650ec506ca00343eeee12 Mon Sep 17 00:00:00 2001
From: michaelgzhang <49577754+mgz-dev@users.noreply.github.com>
Date: Fri, 27 Jan 2023 16:42:11 -0600
Subject: [PATCH] monkeypatch updated get_scheduler for diffusers

Enables use of "num_cycles" and "power" for the cosine_with_restarts and
polynomial learning rate schedulers.
---
 train_network.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 76 insertions(+), 1 deletion(-)

diff --git a/train_network.py b/train_network.py
index 8a8acc7d..6bc8bd08 100644
--- a/train_network.py
+++ b/train_network.py
@@ -35,6 +35,75 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
   return logs
 
 
+# Monkeypatch: a newer get_scheduler() that overrides the current
+# diffusers.optimization.get_scheduler. The code is taken from
+# diffusers.optimization at https://github.com/huggingface/diffusers,
+# commit d87cc15977b87160c30abaace3894e802ad9e1e6, which is a newer release
+# of diffusers than the one currently packaged with sd-scripts.
+# This code can be removed once a newer diffusers version (v0.12.1 or
+# greater) has been tested and adopted by sd-scripts.
+
+from typing import Optional, Union
+from torch.optim import Optimizer
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+
+
+def get_scheduler_fix(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
+):
+    """
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*, defaults to 1):
+            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor for the `POLYNOMIAL` scheduler.
+ """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + +diffusers.optimization.get_scheduler = get_scheduler_fix + + + def train(args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -157,7 +226,9 @@ def train(args): # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( - args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles = args.num_cycles, power = args.power) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: @@ -460,6 +531,10 @@ if __name__ == '__main__': help="only training Text Encoder part / Text Encoder関連部分のみ学習する") parser.add_argument("--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") + parser.add_argument("--num_cycles", type=int, default=1, + help="Number of restarts for cosine scheduler with restarts") + parser.add_argument("--power", type=float, default=1, + help="Polynomial power for polynomial scheduler") args = parser.parse_args() train(args)