From 21f5b618c3c583d53036062adb1bf7e90824644f Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Tue, 14 Feb 2023 19:46:27 +0900 Subject: [PATCH 01/10] Show the moving average loss --- train_db.py | 11 +++++++---- train_network.py | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/train_db.py b/train_db.py index c210767b..a3154cd1 100644 --- a/train_db.py +++ b/train_db.py @@ -206,6 +206,7 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("dreambooth") + loss_list = [] for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -216,7 +217,6 @@ def train(args): if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: text_encoder.train() - loss_total = 0 for step, batch in enumerate(train_dataloader): # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: @@ -291,8 +291,11 @@ def train(args): logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} accelerator.log(logs, step=global_step) - loss_total += current_loss - avr_loss = loss_total / (step+1) + if epoch == 0: + loss_list.append(current_loss) + else: + loss_list[step] = current_loss + avr_loss = sum(loss_list) / len(loss_list) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -300,7 +303,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(train_dataloader)} + logs = {"epoch_loss": sum(loss_list) / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index bb3159fd..c9c3c468 100644 --- a/train_network.py +++ b/train_network.py @@ -378,6 +378,7 @@ def train(args): if accelerator.is_main_process: accelerator.init_trackers("network_train") + loss_list = [] for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") 
train_dataset.set_current_epoch(epoch + 1) @@ -386,7 +387,6 @@ def train(args): network.on_epoch_start(text_encoder, unet) - loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): with torch.no_grad(): @@ -446,8 +446,11 @@ def train(args): global_step += 1 current_loss = loss.detach().item() - loss_total += current_loss - avr_loss = loss_total / (step+1) + if epoch == 0: + loss_list.append(current_loss) + else: + loss_list[step] = current_loss + avr_loss = sum(loss_list) / len(loss_list) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -459,7 +462,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} + logs = {"loss/epoch": sum(loss_list) / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() From 8aed5125deddef0a73fcb9a84c3bacefb0059a11 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Tue, 14 Feb 2023 21:11:30 +0900 Subject: [PATCH 02/10] Removed call of sum() --- train_db.py | 7 +++++-- train_network.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/train_db.py b/train_db.py index a3154cd1..cbcd8071 100644 --- a/train_db.py +++ b/train_db.py @@ -207,6 +207,7 @@ def train(args): accelerator.init_trackers("dreambooth") loss_list = [] + loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -294,8 +295,10 @@ def train(args): if epoch == 0: loss_list.append(current_loss) else: + loss_total -= loss_list[step] loss_list[step] = current_loss - avr_loss = sum(loss_list) / len(loss_list) + loss_total += current_loss + avr_loss = loss_total / len(loss_list) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -303,7 +306,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"epoch_loss": 
sum(loss_list) / len(loss_list)} + logs = {"epoch_loss": loss_total / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index c9c3c468..91d2a0bc 100644 --- a/train_network.py +++ b/train_network.py @@ -379,6 +379,7 @@ def train(args): accelerator.init_trackers("network_train") loss_list = [] + loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset.set_current_epoch(epoch + 1) @@ -449,8 +450,10 @@ def train(args): if epoch == 0: loss_list.append(current_loss) else: + loss_total -= loss_list[step] loss_list[step] = current_loss - avr_loss = sum(loss_list) / len(loss_list) + loss_total += current_loss + avr_loss = loss_total / len(loss_list) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -462,7 +465,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"loss/epoch": sum(loss_list) / len(loss_list)} + logs = {"loss/epoch": loss_total / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() From 496c8cdc098262b5e4796167bf15367daa6ca47b Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 16 Feb 2023 02:56:39 -0800 Subject: [PATCH 03/10] Add noise-offset to metadata --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 1b8046d2..d66f9095 100644 --- a/train_network.py +++ b/train_network.py @@ -353,6 +353,7 @@ def train(args): "ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_seed": args.seed, "ss_keep_tokens": args.keep_tokens, + "ss_noise_offset": args.noise_offset, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), From ffdfd5f6153280afff5929a75c1fc4321bfdfc91 Mon Sep 17 
00:00:00 2001 From: Kohya S Date: Thu, 16 Feb 2023 22:21:36 +0900 Subject: [PATCH 04/10] fix name of loss for epoch --- train_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_db.py b/train_db.py index d36bd8d0..e4f1e54c 100644 --- a/train_db.py +++ b/train_db.py @@ -309,7 +309,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(loss_list)} + logs = {"loss/epoch": loss_total / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() From 3bc0d83769ca7fcf322be3c1f9a025f0dc375880 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 16 Feb 2023 22:21:51 +0900 Subject: [PATCH 05/10] update readme --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 62551f27..03ee5d01 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 16 Feb. 2023, 2023/2/16: + - Noise offset is recorded to the metadata. Thanks to space-nuko! + - Show the moving average loss to prevent loss jumping in ``train_network.py`` and ``train_db.py``. Thanks to shirayu! + - Noise offsetがメタデータに記録されるようになりました。space-nuko氏に感謝します。 + - ``train_network.py``と``train_db.py``で学習中に表示されるlossの値が移動平均になりました。epochの先頭で表示されるlossが大きく変動する事象を解決します。shirayu氏に感謝します。 - 14 Feb. 2023, 2023/2/14: - Add support with multi-gpu trainining for ``train_network.py``. Thanks to Isotr0py! - Add ``--verbose`` option for ``resize_lora.py``. For details, see [this PR](https://github.com/kohya-ss/sd-scripts/pull/179). Thanks to mgz-dev! 
From 78d1fb5ce65af2f96fcf8d4bc2ef18caee1172ba Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Feb 2023 12:08:54 +0800 Subject: [PATCH 06/10] Add '--lowram' argument --- library/train_util.py | 4 +++- train_network.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 415f9b70..1a42d591 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1423,7 +1423,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") parser.add_argument("--noise_offset", type=float, default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") - + parser.add_argument("--lowram", action="store_true", + help="load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)") + if support_dreambooth: # DreamBooth training parser.add_argument("--prior_loss_weight", type=float, default=1.0, diff --git a/train_network.py b/train_network.py index 5983a7ef..e29e0174 100644 --- a/train_network.py +++ b/train_network.py @@ -156,9 +156,10 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - # unnecessary, but work on low-ram device - text_encoder.to("cuda") - unet.to("cuda") + # work on low-ram device + if args.lowram: + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) From dac2bd163ae497fa5f2002739ac495f1ed286080 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 17 Feb 2023 14:19:08 -0500 Subject: [PATCH 07/10] fix git path --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py 
b/library/train_util.py index 415f9b70..0668bd7f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1103,7 +1103,7 @@ def addnet_hash_safetensors(b): def get_git_revision_hash() -> str: try: - return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip() except: return "(unknown)" From a76ad2d1d5888d7e1c1bddbe751f6aead950d8ce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 Feb 2023 15:25:01 +0900 Subject: [PATCH 08/10] add comment for future requirement update --- README-ja.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README-ja.md b/README-ja.md index adf44d2f..064464c0 100644 --- a/README-ja.md +++ b/README-ja.md @@ -64,6 +64,12 @@ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_set accelerate config ``` + + コマンドプロンプトでは以下になります。 From 048e7cd4283d14f969e55da46cc21d66be0bfb53 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 Feb 2023 15:26:14 +0900 Subject: [PATCH 09/10] add lion optimizer support --- fine_tune.py | 7 +++++++ library/train_util.py | 6 ++++-- train_db.py | 7 +++++++ train_network.py | 9 +++++++++ train_textual_inversion.py | 7 +++++++ 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 3ba63063..13241bc6 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -158,6 +158,13 @@ def train(args): raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") print("use 8-bit Adam optimizer") optimizer_class = bnb.optim.AdamW8bit + elif args.use_lion_optimizer: + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print("use Lion optimizer") + optimizer_class = lion_pytorch.Lion else: optimizer_class = torch.optim.AdamW diff --git a/library/train_util.py b/library/train_util.py index 441838e5..63868f98 100644 --- a/library/train_util.py 
+++ b/library/train_util.py @@ -1389,6 +1389,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") parser.add_argument("--use_8bit_adam", action="store_true", help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--use_lion_optimizer", action="store_true", + help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)") parser.add_argument("--mem_eff_attn", action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") parser.add_argument("--xformers", action="store_true", @@ -1424,8 +1426,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--noise_offset", type=float, default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") parser.add_argument("--lowram", action="store_true", - help="load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)") - + help="enable low RAM optimization. e.g. 
load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)") + if support_dreambooth: # DreamBooth training parser.add_argument("--prior_loss_weight", type=float, default=1.0, diff --git a/train_db.py b/train_db.py index e4f1e54c..1903c4c4 100644 --- a/train_db.py +++ b/train_db.py @@ -124,6 +124,13 @@ def train(args): raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") print("use 8-bit Adam optimizer") optimizer_class = bnb.optim.AdamW8bit + elif args.use_lion_optimizer: + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print("use Lion optimizer") + optimizer_class = lion_pytorch.Lion else: optimizer_class = torch.optim.AdamW diff --git a/train_network.py b/train_network.py index e29e0174..b41a52a9 100644 --- a/train_network.py +++ b/train_network.py @@ -156,10 +156,12 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + # work on low-ram device if args.lowram: text_encoder.to("cuda") unet.to("cuda") + # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -214,6 +216,13 @@ def train(args): raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") print("use 8-bit Adam optimizer") optimizer_class = bnb.optim.AdamW8bit + elif args.use_lion_optimizer: + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print("use Lion optimizer") + optimizer_class = lion_pytorch.Lion else: optimizer_class = torch.optim.AdamW diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 010bd04b..ffec0516 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -207,6 +207,13 @@ def train(args): raise ImportError("No bitsand 
bytes / bitsandbytesがインストールされていないようです") print("use 8-bit Adam optimizer") optimizer_class = bnb.optim.AdamW8bit + elif args.use_lion_optimizer: + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print("use Lion optimizer") + optimizer_class = lion_pytorch.Lion else: optimizer_class = torch.optim.AdamW From 5c065eee79fca0e9ac6ff5ec0432bb357a981bf9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 Feb 2023 15:26:21 +0900 Subject: [PATCH 10/10] update readme --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 03ee5d01..a1adcb27 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,19 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 19 Feb. 2023, 2023/2/19: + - Add ``--use_lion_optimizer`` to each training script to use [Lion optimizer](https://github.com/lucidrains/lion-pytorch). + - Please install Lion optimizer with ``pip install lion-pytorch`` (it is not in ``requirements.txt`` currently.) + - Add ``--lowram`` option to ``train_network.py``. Load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle). Thanks to Isotr0py! + - Default behavior (without lowram) has reverted to the same as before 14 Feb. + - Fixed git commit hash to be set correctly regardless of the working directory. Thanks to vladmandic! + + - ``--use_lion_optimizer`` オプションを各学習スクリプトに追加しました。 [Lion optimizer](https://github.com/lucidrains/lion-pytorch) を使用できます。 + - あらかじめ ``pip install lion-pytorch`` でインストールしてください(現在は ``requirements.txt`` に含まれていません)。 + - ``--lowram`` オプションを ``train_network.py`` に追加しました。モデルをRAMではなくVRAMに読み込みます(ColabやKaggleなどRAMに比べてVRAMが多い環境で有効です)。 Isotr0py 氏に感謝します。 + - lowram オプションなしのデフォルト動作は2/14より前と同じに戻しました。 + - git commit hash を現在のフォルダ位置に関わらず正しく取得するように修正しました。vladmandic 氏に感謝します。 + - 16 Feb. 2023, 2023/2/16: - Noise offset is recorded to the metadata. 
Thanks to space-nuko! - Show the moving average loss to prevent loss jumping in ``train_network.py`` and ``train_db.py``. Thanks to shirayu!