From 0fef7b46841da2f6b06650ec506ca00343eeee12 Mon Sep 17 00:00:00 2001
From: michaelgzhang <49577754+mgz-dev@users.noreply.github.com>
Date: Fri, 27 Jan 2023 16:42:11 -0600
Subject: [PATCH 1/5] monkeypatch updated get_scheduler for diffusers

Enables use of "num_cycles" and "power" for the cosine_with_restarts and
polynomial learning rate schedulers.
---
 train_network.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 76 insertions(+), 1 deletion(-)

diff --git a/train_network.py b/train_network.py
index 8a8acc7d..6bc8bd08 100644
--- a/train_network.py
+++ b/train_network.py
@@ -35,6 +35,75 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
   return logs
 
 
+# Monkeypatch: a newer get_scheduler() overriding the current version of diffusers.optimization.get_scheduler.
+# Code is taken from https://github.com/huggingface/diffusers, diffusers.optimization, commit d87cc15977b87160c30abaace3894e802ad9e1e6,
+# which is a newer release of diffusers than the one currently packaged with sd-scripts.
+# This code can be removed once a newer diffusers version (v0.12.1 or greater) is tested and adopted in sd-scripts.
+
+from typing import Optional, Union
+from torch.optim import Optimizer
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+
+def get_scheduler_fix(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
+):
+    """
+    Unified API to get any scheduler from its name.
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See the `POLYNOMIAL` scheduler.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+
+diffusers.optimization.get_scheduler = get_scheduler_fix
+
+
 def train(args):
   session_id = random.randint(0, 2**32)
   training_started_at = time.time()
@@ -157,7 +226,9 @@ def train(args):
 
   # prepare the lr scheduler / lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    num_cycles = args.num_cycles, power = args.power)
 
   # experimental feature: perform fp16 training including gradients; cast the entire model to fp16 / 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -460,6 +531,10 @@ if __name__ == '__main__':
                       help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
   parser.add_argument("--training_comment", type=str, default=None,
                       help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
+  parser.add_argument("--num_cycles", type=int, default=1,
+                      help="Number of restarts for cosine scheduler with restarts")
+  parser.add_argument("--power", type=float, default=1,
+                      help="Polynomial power for polynomial scheduler")
 
   args = parser.parse_args()
   train(args)
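For reviewers who want to exercise the patched factory in isolation, here is a minimal sketch. The optimizer, parameter, and step counts are assumed example values chosen purely for illustration; only get_scheduler_fix itself comes from the patch above.

import torch

# Stand-in parameter and optimizer (assumed values, illustration only).
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# "cosine_with_restarts" maps to SchedulerType.COSINE_WITH_RESTARTS, so
# num_cycles is forwarded to the underlying scheduler; the older packaged
# get_scheduler offers no way to pass it.
lr_scheduler = get_scheduler_fix(
    "cosine_with_restarts", optimizer,
    num_warmup_steps=100,      # assumed warmup length
    num_training_steps=1000,   # assumed total optimizer steps
    num_cycles=3,              # three hard restarts over the run
)

for _ in range(1000):
    optimizer.step()
    lr_scheduler.step()  # advances the LR along the cosine-with-restarts curve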
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler / lr schedulerを用意する
-  lr_scheduler = diffusers.optimization.get_scheduler(
+  # lr_scheduler = diffusers.optimization.get_scheduler(
+  lr_scheduler = get_scheduler_fix(
     args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
     num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    num_cycles = args.num_cycles, power = args.power)
+    num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # experimental feature: perform fp16 training including gradients; cast the entire model to fp16 / 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
@@ -516,6 +514,10 @@ if __name__ == '__main__':
   parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
   parser.add_argument("--text_encoder_lr", type=float, default=None,
                       help="learning rate for Text Encoder / Text Encoderの学習率")
+  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
+                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
+  parser.add_argument("--lr_scheduler_power", type=float, default=1,
+                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
 
   parser.add_argument("--network_weights", type=str, default=None,
                       help="pretrained weights for network / 学習するネットワークの初期重み")
@@ -531,10 +533,6 @@ if __name__ == '__main__':
                       help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
   parser.add_argument("--training_comment", type=str, default=None,
                       help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
-  parser.add_argument("--num_cycles", type=int, default=1,
-                      help="Number of restarts for cosine scheduler with restarts")
-  parser.add_argument("--power", type=float, default=1,
-                      help="Polynomial power for polynomial scheduler")
 
   args = parser.parse_args()
   train(args)

From 6bbb4d426efb97607dd2d587412750da10164d17 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 29 Jan 2023 20:43:58 +0900
Subject: [PATCH 3/5] Fix unet config in Diffusers (sample_size=64)

---
 library/model_util.py                    | 2 +-
 tools/convert_diffusers20_original_sd.py | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/library/model_util.py b/library/model_util.py
index 6a1e656a..778a4c27 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -16,7 +16,7 @@ BETA_END = 0.0120
 UNET_PARAMS_MODEL_CHANNELS = 320
 UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
 UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32  # unused
+UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
 UNET_PARAMS_IN_CHANNELS = 4
 UNET_PARAMS_OUT_CHANNELS = 4
 UNET_PARAMS_NUM_RES_BLOCKS = 2

diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py
index a3cd03fe..6c142848 100644
--- a/tools/convert_diffusers20_original_sd.py
+++ b/tools/convert_diffusers20_original_sd.py
@@ -1,8 +1,4 @@
 # convert Diffusers v1.x/v2.0 model to original Stable Diffusion
-# v1: initial version
-# v2: support safetensors
-# v3: fix to support another format
-# v4: support safetensors in Diffusers
 
 import argparse
 import os
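A note on why ``64`` is the correct value in the model_util.py hunk above: the Diffusers U-Net config stores the latent resolution, not the pixel resolution. A quick sanity check, assuming the well-known SD 1.x/2.0-base training resolution of 512 pixels and the SD VAE's 8x spatial downsampling:

# The U-Net consumes VAE latents, so its sample_size is the pixel resolution
# divided by the VAE scale factor (both constants are standard SD values,
# stated here as assumptions, not read from the patch).
pixel_resolution = 512  # SD 1.x / 2.0-base training resolution
vae_scale_factor = 8    # the SD VAE downsamples 8x in each spatial dimension

sample_size = pixel_resolution // vae_scale_factor
assert sample_size == 64  # matches the corrected UNET_PARAMS_IMAGE_SIZE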
From 86eba1d2cfad38a3ddb61d826787fd324e6d83d8 Mon Sep 17 00:00:00 2001
From: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Date: Sun, 29 Jan 2023 21:23:05 +0900
Subject: [PATCH 4/5] Update README.md

---
 README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/README.md b/README.md
index 70fdebf1..4497a3c2 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,11 @@
 __Stable Diffusion web UI now seems to support LoRA trained by
 ``sd-scripts``.__ Note: LoRA models for SD 2.x are not supported in the Web UI either.
 
+- 29 Jan. 2023, 2023/1/29
+  - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options to ``train_network.py``, for the cosine_with_restarts and polynomial learning rate schedulers.
+  - Fixed the U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format in ``convert_diffusers20_original_sd.py``.
+  - ``--lr_scheduler_num_cycles`` と ``--lr_scheduler_power`` オプションを ``train_network.py`` に追加しました。前者は cosine_with_restarts、後者は polynomial の学習率スケジューラに有効です。
+  - ``convert_diffusers20_original_sd.py`` で SD 形式から Diffusers に変換するときの U-Net の ``sample_size`` パラメータを ``64`` に修正しました。
 - 26 Jan. 2023, 2023/1/26
   - Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
   - Textual Inversionの学習をサポートしました。ドキュメントは[こちら](./train_ti_README-ja.md)。

From 4cabb3797731560bf71b54967e285f0c0c34faf6 Mon Sep 17 00:00:00 2001
From: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Date: Sun, 29 Jan 2023 21:50:17 +0900
Subject: [PATCH 5/5] Update README.md

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 4497a3c2..31c08a6e 100644
--- a/README.md
+++ b/README.md
@@ -7,9 +7,9 @@ __Stable Diffusion web UI now seems to support LoRA trained by
 ``sd-scripts``.__ Note: LoRA models for SD 2.x are not supported in the Web UI either.
 
 - 29 Jan. 2023, 2023/1/29
-  - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options to ``train_network.py``, for the cosine_with_restarts and polynomial learning rate schedulers.
+  - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options to ``train_network.py``, for the cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
   - Fixed the U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format in ``convert_diffusers20_original_sd.py``.
-  - ``--lr_scheduler_num_cycles`` と ``--lr_scheduler_power`` オプションを ``train_network.py`` に追加しました。前者は cosine_with_restarts、後者は polynomial の学習率スケジューラに有効です。
+  - ``--lr_scheduler_num_cycles`` と ``--lr_scheduler_power`` オプションを ``train_network.py`` に追加しました。前者は cosine_with_restarts、後者は polynomial の学習率スケジューラに有効です。mgz-dev氏に感謝します。
   - ``convert_diffusers20_original_sd.py`` で SD 形式から Diffusers に変換するときの U-Net の ``sample_size`` パラメータを ``64`` に修正しました。
 - 26 Jan. 2023, 2023/1/26
   - Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
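To complement the README entries above, a short sketch of what ``--lr_scheduler_power`` feeds into once ``train_network.py`` hands it to ``get_scheduler_fix``. The optimizer, parameter, and step counts are assumed toy values; power=1.0 gives a plain linear decay, while larger powers decay faster early in training.

import torch

param = [torch.nn.Parameter(torch.zeros(1))]

for power in (1.0, 2.0):
    opt = torch.optim.AdamW(param, lr=1e-4)
    sched = get_scheduler_fix(
        "polynomial", opt,
        num_warmup_steps=0,     # assumed: no warmup
        num_training_steps=10,  # assumed: toy-sized run
        power=power,            # what --lr_scheduler_power ends up controlling
    )
    lrs = []
    for _ in range(10):
        opt.step()
        sched.step()
        lrs.append(opt.param_groups[0]["lr"])
    print(f"power={power}:", [round(lr, 7) for lr in lrs])  # steeper early decay for power=2.0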