From 82aac2646945c8b483b9272d712eca16fe60b7aa Mon Sep 17 00:00:00 2001 From: rvhfxb <116002789+rvhfxb@users.noreply.github.com> Date: Wed, 8 Mar 2023 22:42:41 +0900 Subject: [PATCH 01/12] Update train_util.py --- library/train_util.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 75176e13..f1060cbb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -906,10 +906,14 @@ class FineTuningDataset(BaseDataset): if os.path.exists(image_key): abs_path = image_key else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + npz_path = os.path.join(glob.escape(train_data_dir), image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] caption = img_md.get('caption') tags = img_md.get('tags') From 68cd874bb68f30e792243659fbeb8c98733cd365 Mon Sep 17 00:00:00 2001 From: mio <74481573+mio2333@users.noreply.github.com> Date: Fri, 10 Mar 2023 18:29:34 +0800 Subject: [PATCH 02/12] Append sys path for import_module This will be better if we run the scripts we do not run the training script from the current directory. This is reasonable as some other projects will use this as a subfolder, such as https://github.com/ddPn08/kohya-sd-scripts-webui. I can not run the script without adding this. --- train_network.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_network.py b/train_network.py index cf64c894..8be8305c 100644 --- a/train_network.py +++ b/train_network.py @@ -134,6 +134,8 @@ def train(args): gc.collect() # prepare network + import sys + sys.path.append(os.path.dirname(__file__)) print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) From 4ad8e75291ce77974b6441c9710a459cc95ee802 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Mar 2023 21:10:22 +0900 Subject: [PATCH 03/12] fix to work with dim>320 --- networks/resize_lora.py | 1 - networks/svd_merge_lora.py | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 1a8110c4..dfacd666 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -1,6 +1,5 @@ # Convert LoRA to different rank approximation (should only be used to go to lower rank) # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo and kohya import argparse import torch diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index c8e39b80..3a03b0d5 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype): return sd -def save_to_file(file_name, model, state_dict, dtype): +def save_to_file(file_name, state_dict, dtype): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) + save_file(state_dict, file_name) else: - torch.save(model, file_name) + torch.save(state_dict, file_name) def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): @@ -105,6 +105,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty mat = mat.squeeze() module_new_rank = new_conv_rank if conv2d_3x3 else new_rank + module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim U, S, Vh = torch.linalg.svd(mat) @@ -156,7 +157,7 @@ def merge(args): state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + save_to_file(args.save_to, state_dict, save_dtype) if __name__ == '__main__': From 75d1883da630c033841ff7fc79a94ca7131dd3d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Mar 2023 21:12:15 +0900 Subject: [PATCH 04/12] fix LoRA rank is limited to target dim --- networks/extract_lora_from_models.py | 33 +++++++------------ networks/lora.py | 49 +++++++++++++++++----------- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 5d77b9e5..b5d18d9b 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -103,7 +103,8 @@ def svd(args): if args.device: mat = mat.to(args.device) - # print(mat.size(), mat.device, rank, in_dim, out_dim) + + # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -137,27 +138,17 @@ def svd(args): lora_weights[lora_name] = (U, Vh) # make state dict for LoRA - lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict - lora_sd = lora_network_o.state_dict() - print(f"LoRA has {len(lora_sd)} weights.") - - for key in list(lora_sd.keys()): - if "alpha" in key: - continue - - lora_name = key.split('.')[0] - i = 0 if "lora_up" in key else 1 - - weights = lora_weights[lora_name][i] - # print(key, i, weights.size(), lora_sd[key].size()) - # if len(lora_sd[key].size()) == 4: - # weights = weights.unsqueeze(2).unsqueeze(3) - - assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" - lora_sd[key] = weights + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + '.lora_up.weight'] = up_weight + lora_sd[lora_name + '.lora_down.weight'] = down_weight + lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - info = lora_network_o.load_state_dict(lora_sd) + lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + + info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(args.save_to) @@ -167,7 +158,7 @@ def svd(args): # minimum metadata metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} - lora_network_o.save_weights(args.save_to, save_dtype, metadata) + lora_network_save.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") diff --git a/networks/lora.py b/networks/lora.py index c0181c02..6d3875dc 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module): """ if alpha == 0 or None, alpha is rank (no scaling). """ super().__init__() self.lora_name = lora_name - self.lora_dim = lora_dim if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features - self.lora_dim = min(self.lora_dim, in_dim, out_dim) - if self.lora_dim != lora_dim: - print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + if org_module.__class__.__name__ == 'Conv2d': kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: - in_dim = org_module.in_features - out_dim = org_module.out_features - self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = lora_dim if alpha is None or alpha == 0 else alpha + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える @@ -149,12 +153,13 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un return network -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location='cpu') +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file, safe_open + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location='cpu') # get dim/alpha mapping modules_dim = {} @@ -174,7 +179,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa # support old LoRA without alpha for key in modules_dim.keys(): if key not in modules_alpha: - modules_alpha = modules_dim[key] + modules_alpha = modules_dim[key] network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) network.weights_sd = weights_sd @@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa class LoRANetwork(torch.nn.Module): # is it possible to apply conv_in and conv_out? - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' @@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module): text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") self.weights_sd = None @@ -371,7 +382,7 @@ class LoRANetwork(torch.nn.Module): else: torch.save(state_dict, file) - @staticmethod + @ staticmethod def set_regions(networks, image): image = image.astype(np.float32) / 255.0 for i, network in enumerate(networks[:3]): From 618592c52b9e82f7abc78105ecc23014a3505b19 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Mar 2023 21:31:59 +0900 Subject: [PATCH 05/12] npz check to use subset, add dadap warn close #274 --- library/train_util.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 68bce108..718fe36d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -912,7 +912,7 @@ class FineTuningDataset(BaseDataset): if os.path.exists(image_key): abs_path = image_key else: - npz_path = os.path.join(glob.escape(train_data_dir), image_key + ".npz") + npz_path = os.path.join(subset.image_dir, image_key + ".npz") if os.path.exists(npz_path): abs_path = npz_path else: @@ -1761,15 +1761,22 @@ def get_optimizer(args, trainable_params): raise ImportError("No dadaptation / dadaptation がインストールされていないようです") print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - min_lr = lr + actual_lr = lr + lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) for group in trainable_params: - min_lr = min(min_lr, group.get("lr", lr)) + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) - if min_lr <= 0.1: + if actual_lr <= 0.1: print( - f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}') + f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}') print('recommend option: lr=1.0 / 推奨は1.0です') + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}") optimizer_class = dadaptation.DAdaptAdam optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) From c78c51c78f26f3bba646c4f86bf769fdfd236fd9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Mar 2023 21:59:25 +0900 Subject: [PATCH 06/12] update documents --- README.md | 25 ++++++++++++++++++++++++- train_README-ja.md | 8 ++++++++ train_network_README-ja.md | 4 ++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index aaf371cb..437b1120 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,30 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History -- 9 Mar. 2023, 2023/3/9: +- 10 Mar. 2023, 2023/3/10: release v0.5.1 + - Fix to LoRA modules in the model are same to the previous (before 0.5.0) if Conv2d-3x3 is disabled (no `conv_dim` arg, default). + - Conv2D with kernel size 1x1 in ResNet modules were accidentally included in v0.5.0. + - Trained models with v0.5.0 will work with Web UI's built-in LoRA and Additional Networks extension. + - Fix an issue that dim (rank) of LoRA module is limited to the in/out dimensions of the target Linear/Conv2d (in case of the dim > 320). + - `resize_lora.py` now have a feature to `dynamic resizing` which means each LoRA module can have different ranks (dims). Thanks to mgz-dev for this great work! + - The appropriate rank is selected based on the complexity of each module with an algorithm specified in the command line arguments. For details: https://github.com/kohya-ss/sd-scripts/pull/243 + - Multiple GPUs training is finally supported in `train_network.py`. Thanks to ddPn08 to solve this long running issue! + - Dataset with fine-tuning method (with metadata json) now works without images if `.npz` files exist. Thanks to rvhfxb! + - `train_network.py` can work if the current directory is not the directory where the script is in. Thanks to mio2333! + - Fix `extract_lora_from_models.py` and `svd_merge_lora.py` doesn't work with higher rank (>320). + + - LoRAのConv2d-3x3拡張を行わない場合(`conv_dim` を指定しない場合)、以前(v0.5.0)と同じ構成になるよう修正しました。 + - ResNetのカーネルサイズ1x1のConv2dが誤って対象になっていました。 + - ただv0.5.0で学習したモデルは Additional Networks 拡張、およびWeb UIのLoRA機能で問題なく使えると思われます。 + - LoRAモジュールの dim (rank) が、対象モジュールの次元数以下に制限される不具合を修正しました(320より大きい dim を指定した場合)。 + - `resize_lora.py` に `dynamic resizing` (リサイズ後の各LoRAモジュールが異なるrank (dim) を持てる機能)を追加しました。mgz-dev 氏の貢献に感謝します。 + - 適切なランクがコマンドライン引数で指定したアルゴリズムにより自動的に選択されます。詳細はこちらをご覧ください: https://github.com/kohya-ss/sd-scripts/pull/243 + - `train_network.py` でマルチGPU学習をサポートしました。長年の懸案を解決された ddPn08 氏に感謝します。 + - fine-tuning方式のデータセット(メタデータ.jsonファイルを使うデータセット)で `.npz` が存在するときには画像がなくても動作するようになりました。rvhfxb 氏に感謝します。 + - 他のディレクトリから `train_network.py` を呼び出しても動作するよう変更しました。 mio2333 氏に感謝します。 + - `extract_lora_from_models.py` および `svd_merge_lora.py` が320より大きいrankを指定すると動かない不具合を修正しました。 + +- 9 Mar. 2023, 2023/3/9: release v0.5.0 - There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while. - Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254 - `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1). diff --git a/train_README-ja.md b/train_README-ja.md index 479f9604..d5f1b5fc 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。 +- `--persistent_data_loader_workers` + + Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。 + +- `--max_data_loader_n_workers` + + データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。 + - `--logging_dir` / `--log_prefix` 学習ログの保存に関するオプションです。logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 4a79a6f7..79d1709f 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。 * `--network_alpha` * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。 +* `--persistent_data_loader_workers` + * Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。 +* `--max_data_loader_n_workers` + * データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください。 * `--network_weights` * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。 * `--network_train_unet_only` From b1774608074368b85f9a44659706eff5e9cd52bb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 10 Mar 2023 22:02:17 +0900 Subject: [PATCH 07/12] restore comment --- networks/resize_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index dfacd666..09a19c19 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -1,5 +1,6 @@ # Convert LoRA to different rank approximation (should only be used to go to lower rank) # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo import argparse import torch From 8b259297656601e020df1948b6e8507da39dcd71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 11 Mar 2023 08:03:02 +0900 Subject: [PATCH 08/12] fix device error --- README.md | 3 +++ networks/svd_merge_lora.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/README.md b/README.md index 437b1120..07e1cfda 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 11 Mar. 2023, 2023/3/11: + - Fix `svd_merge_lora.py` causes an error about the device. + - `svd_merge_lora.py` でデバイス関連のエラーが発生する不具合を修正しました。 - 10 Mar. 2023, 2023/3/10: release v0.5.1 - Fix to LoRA modules in the model are same to the previous (before 0.5.0) if Conv2d-3x3 is disabled (no `conv_dim` arg, default). - Conv2D with kernel size 1x1 in ResNet modules were accidentally included in v0.5.0. diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 3a03b0d5..8c0d8183 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -77,6 +77,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = (alpha / network_dim) + + if device: # and isinstance(scale, torch.Tensor): + scale = scale.to(device) + up_weight = up_weight.to(device) + down_weight = down_weight.to(device) + if not conv2d: # linear weight = weight + ratio * (up_weight @ down_weight) * scale elif kernel_size == (1, 1): From 0b38e663fd667c05bcdb473c39dd1fb3552b19a1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 11 Mar 2023 08:04:28 +0900 Subject: [PATCH 09/12] remove unnecessary device change --- networks/svd_merge_lora.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 8c0d8183..73228769 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -80,8 +80,6 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) - up_weight = up_weight.to(device) - down_weight = down_weight.to(device) if not conv2d: # linear weight = weight + ratio * (up_weight @ down_weight) * scale From 44d4cfb4539cb567fb3ad909cd752b8fe85029b0 Mon Sep 17 00:00:00 2001 From: Linaqruf Date: Sun, 12 Mar 2023 11:52:37 +0700 Subject: [PATCH 10/12] feat: added function to load training config with .toml --- fine_tune.py | 21 +++++++++++++++++++++ finetune/merge_captions_to_metadata.py | 5 ++++- finetune/merge_dd_tags_to_metadata.py | 5 ++++- library/train_util.py | 2 ++ train_db.py | 23 ++++++++++++++++++++++- train_network.py | 23 ++++++++++++++++++++++- train_textual_inversion.py | 21 +++++++++++++++++++++ 7 files changed, 96 insertions(+), 4 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 12557597..20f94cd4 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,6 +5,7 @@ import argparse import gc import math import os +import toml from tqdm import tqdm import torch @@ -362,4 +363,24 @@ if __name__ == '__main__': parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") args = parser.parse_args() + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + train(args) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index cbc5033f..491e4591 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from tqdm import tqdm import library.train_util as train_util - +import os def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" @@ -29,6 +29,9 @@ def main(args): caption_path = image_path.with_suffix(args.caption_extension) caption = caption_path.read_text(encoding='utf-8').strip() + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 4285feb0..8823a9c8 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from tqdm import tqdm import library.train_util as train_util - +import os def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" @@ -29,6 +29,9 @@ def main(args): tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} diff --git a/library/train_util.py b/library/train_util.py index 718fe36d..248156a3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1598,6 +1598,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類') + + parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") if support_dreambooth: # DreamBooth training diff --git a/train_db.py b/train_db.py index a3021177..5fd3c65b 100644 --- a/train_db.py +++ b/train_db.py @@ -7,6 +7,7 @@ import argparse import itertools import math import os +import toml from tqdm import tqdm import torch @@ -361,4 +362,24 @@ if __name__ == '__main__': help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") args = parser.parse_args() - train(args) + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + + train(args) \ No newline at end of file diff --git a/train_network.py b/train_network.py index 5aa8af48..454bd254 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import os import random import time import json +import toml from tqdm import tqdm import torch @@ -656,4 +657,24 @@ if __name__ == '__main__': help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() - train(args) + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + + train(args) \ No newline at end of file diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 34b7f092..7cfaedfe 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -3,6 +3,7 @@ import argparse import gc import math import os +import toml from tqdm import tqdm import torch @@ -523,4 +524,24 @@ if __name__ == '__main__': help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する") args = parser.parse_args() + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + train(args) From c3f9eb10f169ebf7b55dc6af59f86603fb7f8504 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 18 Mar 2023 18:58:12 +0900 Subject: [PATCH 11/12] format with black --- fine_tune.py | 707 +++--- library/train_util.py | 4281 ++++++++++++++++++++---------------- train_db.py | 713 +++--- train_network.py | 1224 ++++++----- train_textual_inversion.py | 885 ++++---- 5 files changed, 4202 insertions(+), 3608 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 20f94cd4..e3cf247e 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,7 +5,7 @@ import argparse import gc import math import os -import toml +import toml from tqdm import tqdm import torch @@ -16,371 +16,416 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, + ConfigSanitizer, + BlueprintGenerator, ) + def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - user_config = { - "datasets": [{ - "subsets": [{ - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - }] - }] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") - return - - if cache_latents: - assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) - - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - - # Diffusers版のxformers使用フラグを設定する関数 - def set_diffusers_xformers_flag(model, valid): - # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう - # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) - # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか - # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - - # Recursively walk through all the children. - # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - fn_recursive_set_mem_eff(model) - - # モデルに xformers とか memory efficient attention を組み込む - if args.diffusers_xformers: - print("Use xformers by Diffusers") - set_diffusers_xformers_flag(unet, True) - else: - # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - print("Disable Diffusers' xformers") - set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # 学習を準備する:モデルを適切な状態にする - training_models = [] - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) - - if args.train_text_encoder: - print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - text_encoder.train() # required for gradient_checkpointing + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - text_encoder.eval() + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - for m in training_models: - m.requires_grad_(True) - params = [] - for m in training_models: - params.extend(m.parameters()) - params_to_optimize = params + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path - # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう + # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) + # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか + # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + for child in module.children(): + fn_recursive_set_mem_eff(child) - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + fn_recursive_set_mem_eff(model) - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + print("Use xformers by Diffusers") + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある + print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - if accelerator.is_main_process: - accelerator.init_trackers("finetuning") + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + if args.train_text_encoder: + print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは学習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) for m in training_models: - m.train() + m.requires_grad_(True) + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params - loss_total = 0 - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - with torch.set_grad_enabled(args.train_text_encoder): - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args.lr_scheduler, + optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_scheduler_num_cycles, + power=args.lr_scheduler_power, + ) - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("finetuning") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + for m in training_models: + m.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + # TODO moving averageにする + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + + args = parser.parse_args() + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) else: - target = noise + print(f"{config_path} not found.") - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - accelerator.log(logs, step=global_step) - - # TODO moving averageにする - loss_total += current_loss - avr_loss = loss_total / (step+1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch+1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - is_main_process = accelerator.is_main_process - if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, global_step, text_encoder, unet, vae) - print("model saved.") - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - - parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers / Diffusersでxformersを使用する') - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") - - args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") - - train(args) + train(args) diff --git a/library/train_util.py b/library/train_util.py index 248156a3..230985ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -32,10 +32,20 @@ from transformers import CLIPTokenizer import transformers import diffusers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION -from diffusers import (StableDiffusionPipeline, DDPMScheduler, - EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, - KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler) +from diffusers import ( + StableDiffusionPipeline, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, +) import albumentations as albu import numpy as np from PIL import Image @@ -48,7 +58,7 @@ import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" @@ -64,1086 +74,1241 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux? -class ImageInfo(): - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: Tuple[int, int] = None - self.resized_size: Tuple[int, int] = None - self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_npz_flipped: str = None +class ImageInfo: + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None -class BucketManager(): - def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: - self.no_upscale = no_upscale - if max_reso is None: - self.max_reso = None - self.max_area = None - else: - self.max_reso = max_reso - self.max_area = max_reso[0] * max_reso[1] - self.min_size = min_size - self.max_size = max_size - self.reso_steps = reso_steps - - self.resos = [] - self.reso_to_id = {} - self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key - - def add_image(self, reso, image): - bucket_id = self.reso_to_id[reso] - self.buckets[bucket_id].append(image) - - def shuffle(self): - for bucket in self.buckets: - random.shuffle(bucket) - - def sort(self): - # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す - sorted_resos = self.resos.copy() - sorted_resos.sort() - - sorted_buckets = [] - sorted_reso_to_id = {} - for i, reso in enumerate(sorted_resos): - bucket_id = self.reso_to_id[reso] - sorted_buckets.append(self.buckets[bucket_id]) - sorted_reso_to_id[reso] = i - - self.resos = sorted_resos - self.buckets = sorted_buckets - self.reso_to_id = sorted_reso_to_id - - def make_buckets(self): - resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) - self.set_predefined_resos(resos) - - def set_predefined_resos(self, resos): - # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく - self.predefined_resos = resos.copy() - self.predefined_resos_set = set(resos) - self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) - - def add_if_new_reso(self, reso): - if reso not in self.reso_to_id: - bucket_id = len(self.resos) - self.reso_to_id[reso] = bucket_id - self.resos.append(reso) - self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) - - def round_to_steps(self, x): - x = int(x + .5) - return x - x % self.reso_steps - - def select_bucket(self, image_width, image_height): - aspect_ratio = image_width / image_height - if not self.no_upscale: - # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する - reso = (image_width, image_height) - if reso in self.predefined_resos_set: - pass - else: - ar_errors = self.predefined_aspect_ratios - aspect_ratio - predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの - reso = self.predefined_resos[predefined_bucket_id] - - ar_reso = reso[0] / reso[1] - if aspect_ratio > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - # print("use predef", image_width, image_height, reso, resized_size) - else: - if image_width * image_height > self.max_area: - # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める - resized_width = math.sqrt(self.max_area * aspect_ratio) - resized_height = self.max_area / resized_width - assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" - - # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ - # 元のbucketingと同じロジック - b_width_rounded = self.round_to_steps(resized_width) - b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) - ar_width_rounded = b_width_rounded / b_height_in_wr - - b_height_rounded = self.round_to_steps(resized_height) - b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) - ar_height_rounded = b_width_in_hr / b_height_rounded - - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) - - if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): - resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5)) +class BucketManager: + def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: + self.no_upscale = no_upscale + if max_reso is None: + self.max_reso = None + self.max_area = None else: - resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded) - # print(resized_size) - else: - resized_size = (image_width, image_height) # リサイズは不要 + self.max_reso = max_reso + self.max_area = max_reso[0] * max_reso[1] + self.min_size = min_size + self.max_size = max_size + self.reso_steps = reso_steps - # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) - bucket_width = resized_size[0] - resized_size[0] % self.reso_steps - bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + self.resos = [] + self.reso_to_id = {} + self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key - reso = (bucket_width, bucket_height) + def add_image(self, reso, image): + bucket_id = self.reso_to_id[reso] + self.buckets[bucket_id].append(image) - self.add_if_new_reso(reso) + def shuffle(self): + for bucket in self.buckets: + random.shuffle(bucket) - ar_error = (reso[0] / reso[1]) - aspect_ratio - return reso, resized_size, ar_error + def sort(self): + # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す + sorted_resos = self.resos.copy() + sorted_resos.sort() + + sorted_buckets = [] + sorted_reso_to_id = {} + for i, reso in enumerate(sorted_resos): + bucket_id = self.reso_to_id[reso] + sorted_buckets.append(self.buckets[bucket_id]) + sorted_reso_to_id[reso] = i + + self.resos = sorted_resos + self.buckets = sorted_buckets + self.reso_to_id = sorted_reso_to_id + + def make_buckets(self): + resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) + self.set_predefined_resos(resos) + + def set_predefined_resos(self, resos): + # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく + self.predefined_resos = resos.copy() + self.predefined_resos_set = set(resos) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) + + def add_if_new_reso(self, reso): + if reso not in self.reso_to_id: + bucket_id = len(self.resos) + self.reso_to_id[reso] = bucket_id + self.resos.append(reso) + self.buckets.append([]) + # print(reso, bucket_id, len(self.buckets)) + + def round_to_steps(self, x): + x = int(x + 0.5) + return x - x % self.reso_steps + + def select_bucket(self, image_width, image_height): + aspect_ratio = image_width / image_height + if not self.no_upscale: + # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する + reso = (image_width, image_height) + if reso in self.predefined_resos_set: + pass + else: + ar_errors = self.predefined_aspect_ratios - aspect_ratio + predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの + reso = self.predefined_resos[predefined_bucket_id] + + ar_reso = reso[0] / reso[1] + if aspect_ratio > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + + resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) + # print("use predef", image_width, image_height, reso, resized_size) + else: + if image_width * image_height > self.max_area: + # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める + resized_width = math.sqrt(self.max_area * aspect_ratio) + resized_height = self.max_area / resized_width + assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" + + # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ + # 元のbucketingと同じロジック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # print(b_width_rounded, b_height_in_wr, ar_width_rounded) + # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) + else: + resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) + # print(resized_size) + else: + resized_size = (image_width, image_height) # リサイズは不要 + + # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) + bucket_width = resized_size[0] - resized_size[0] % self.reso_steps + bucket_height = resized_size[1] - resized_size[1] % self.reso_steps + # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + + reso = (bucket_width, bucket_height) + + self.add_if_new_reso(reso) + + ar_error = (reso[0] / reso[1]) - aspect_ratio + return reso, resized_size, ar_error class BucketBatchIndex(NamedTuple): - bucket_index: int - bucket_batch_size: int - batch_index: int + bucket_index: int + bucket_batch_size: int + batch_index: int class AugHelper: - def __init__(self): - # prepare all possible augmentators - color_aug_method = albu.OneOf([ - albu.HueSaturationValue(8, 0, 0, p=.5), - albu.RandomGamma((95, 105), p=.5), - ], p=.33) - flip_aug_method = albu.HorizontalFlip(p=0.5) + def __init__(self): + # prepare all possible augmentators + color_aug_method = albu.OneOf( + [ + albu.HueSaturationValue(8, 0, 0, p=0.5), + albu.RandomGamma((95, 105), p=0.5), + ], + p=0.33, + ) + flip_aug_method = albu.HorizontalFlip(p=0.5) - # key: (use_color_aug, use_flip_aug) - self.augmentors = { - (True, True): albu.Compose([ - color_aug_method, - flip_aug_method, - ], p=1.), - (True, False): albu.Compose([ - color_aug_method, - ], p=1.), - (False, True): albu.Compose([ - flip_aug_method, - ], p=1.), - (False, False): None - } + # key: (use_color_aug, use_flip_aug) + self.augmentors = { + (True, True): albu.Compose( + [ + color_aug_method, + flip_aug_method, + ], + p=1.0, + ), + (True, False): albu.Compose( + [ + color_aug_method, + ], + p=1.0, + ), + (False, True): albu.Compose( + [ + flip_aug_method, + ], + p=1.0, + ), + (False, False): None, + } - def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: - return self.augmentors[(use_color_aug, use_flip_aug)] + def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: + return self.augmentors[(use_color_aug, use_flip_aug)] class BaseSubset: - def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None: - self.image_dir = image_dir - self.num_repeats = num_repeats - self.shuffle_caption = shuffle_caption - self.keep_tokens = keep_tokens - self.color_aug = color_aug - self.flip_aug = flip_aug - self.face_crop_aug_range = face_crop_aug_range - self.random_crop = random_crop - self.caption_dropout_rate = caption_dropout_rate - self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs - self.caption_tag_dropout_rate = caption_tag_dropout_rate + def __init__( + self, + image_dir: Optional[str], + num_repeats: int, + shuffle_caption: bool, + keep_tokens: int, + color_aug: bool, + flip_aug: bool, + face_crop_aug_range: Optional[Tuple[float, float]], + random_crop: bool, + caption_dropout_rate: float, + caption_dropout_every_n_epochs: int, + caption_tag_dropout_rate: float, + ) -> None: + self.image_dir = image_dir + self.num_repeats = num_repeats + self.shuffle_caption = shuffle_caption + self.keep_tokens = keep_tokens + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate - self.img_count = 0 + self.img_count = 0 class DreamBoothSubset(BaseSubset): - def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: - assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + def __init__( + self, + image_dir: str, + is_reg: bool, + class_tokens: Optional[str], + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" - super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, - face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) - self.is_reg = is_reg - self.class_tokens = class_tokens - self.caption_extension = caption_extension + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension - def __eq__(self, other) -> bool: - if not isinstance(other, DreamBoothSubset): - return NotImplemented - return self.image_dir == other.image_dir + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == other.image_dir class FineTuningSubset(BaseSubset): - def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: - assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" + def __init__( + self, + image_dir, + metadata_file: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) -> None: + assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" - super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, - face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) - self.metadata_file = metadata_file + self.metadata_file = metadata_file - def __eq__(self, other) -> bool: - if not isinstance(other, FineTuningSubset): - return NotImplemented - return self.metadata_file == other.metadata_file + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file class BaseDataset(torch.utils.data.Dataset): - def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None: - super().__init__() - self.tokenizer = tokenizer - self.max_token_length = max_token_length - # width/height is used when enable_bucket==False - self.width, self.height = (None, None) if resolution is None else resolution - self.debug_dataset = debug_dataset - - self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] - - self.token_padding_disabled = False - self.tag_frequency = {} - - self.enable_bucket = False - self.bucket_manager: BucketManager = None # not initialized - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None - self.bucket_no_upscale = None - self.bucket_info = None # for metadata - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ - - # augmentation - self.aug_helper = AugHelper() - - self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) - - self.image_data: Dict[str, ImageInfo] = {} - self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} - - self.replacements = {} - - def set_current_epoch(self, epoch): - self.current_epoch = epoch - self.shuffle_buckets() - - def set_tag_frequency(self, dir_name, captions): - frequency_for_dir = self.tag_frequency.get(dir_name, {}) - self.tag_frequency[dir_name] = frequency_for_dir - for caption in captions: - for tag in caption.split(","): - tag = tag.strip() - if tag: - tag = tag.lower() - frequency = frequency_for_dir.get(tag, 0) - frequency_for_dir[tag] = frequency + 1 - - def disable_token_padding(self): - self.token_padding_disabled = True - - def add_replacement(self, str_from, str_to): - self.replacements[str_from] = str_to - - def process_caption(self, subset: BaseSubset, caption): - # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い - is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate - is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 - - if is_drop_out: - caption = "" - else: - if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: - def dropout_tags(tokens): - if subset.caption_tag_dropout_rate <= 0: - return tokens - l = [] - for token in tokens: - if random.random() >= subset.caption_tag_dropout_rate: - l.append(token) - return l - - fixed_tokens = [] - flex_tokens = [t.strip() for t in caption.strip().split(",")] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[:subset.keep_tokens] - flex_tokens = flex_tokens[subset.keep_tokens:] - - if subset.shuffle_caption: - random.shuffle(flex_tokens) - - flex_tokens = dropout_tags(flex_tokens) - - caption = ", ".join(fixed_tokens + flex_tokens) - - # textual inversion対応 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) - else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) - - return caption - - def get_input_ids(self, caption): - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - return input_ids - - def register_image(self, info: ImageInfo, subset: BaseSubset): - self.image_data[info.image_key] = info - self.image_to_subset[info.image_key] = subset - - def make_buckets(self): - ''' - bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - min_size and max_size are ignored when enable_bucket is False - ''' - print("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) - - if self.enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketを作成し、画像をbucketに振り分ける - if self.enable_bucket: - if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み - self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height), - self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps) - if not self.bucket_no_upscale: - self.bucket_manager.make_buckets() - else: - print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") - - img_ar_errors = [] - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height) - - # print(image_info.image_key, image_info.bucket_reso) - img_ar_errors.append(abs(ar_error)) - - self.bucket_manager.sort() - else: - self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) - self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) - - for image_info in self.image_data.values(): - for _ in range(image_info.num_repeats): - self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) - - # bucket情報を表示、格納する - if self.enable_bucket: - self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") - for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): - count = len(bucket) - if count > 0: - self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) - self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") - - # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] - for bucket_index, bucket in enumerate(self.bucket_manager.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - self.bucket_manager.shuffle() - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): - image_height, image_width = image.shape[0:2] - - if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - - image_height, image_width = image.shape[0:2] - if image_width > reso[0]: - trim_size = image_width - reso[0] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) - image = image[:, p:p + reso[0]] - if image_height > reso[1]: - trim_size = image_height - reso[1] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) - image = image[p:p + reso[1]] - - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def is_latent_cacheable(self): - return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) - - def cache_latents(self, vae): - # TODO ここを高速化したい - print("caching latents.") - for info in tqdm(self.image_data.values()): - subset = self.image_to_subset[info.image_key] - - if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) # might be None - if info.latents_flipped is not None: - info.latents_flipped = torch.FloatTensor(info.latents_flipped) - continue - - image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) - - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - if subset.flip_aug: - image = image[:, ::-1].copy() # cannot convert to Tensor without copy - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size - - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = self.load_image(image_path) - - face_cx = face_cy = face_w = face_h = 0 - if subset.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへ切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if subset.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz - if npz_file is None: - return None - return np.load(npz_file)['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] - bucket_batch_size = self.buckets_indices[index].bucket_batch_size - image_index = self.buckets_indices[index].batch_index * bucket_batch_size - - loss_weights = [] - captions = [] - input_ids_list = [] - latents_list = [] - images = [] - - for image_key in bucket[image_index:image_index + bucket_batch_size]: - image_info = self.image_data[image_key] - subset = self.image_to_subset[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) - - # image/latentsを処理する - if image_info.latents is not None: - latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped - image = None - elif image_info.latents_npz is not None: - latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5) - latents = torch.FloatTensor(latents) - image = None - else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + def __init__( + self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.max_token_length = max_token_length + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.debug_dataset = debug_dataset + + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] + + self.token_padding_disabled = False + self.tag_frequency = {} + + self.enable_bucket = False + self.bucket_manager: BucketManager = None # not initialized + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = None + self.bucket_info = None # for metadata + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) - if aug is not None: - img = aug(image=img)['image'] + self.aug_helper = AugHelper() - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + self.image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) - images.append(image) - latents_list.append(latents) + self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} - caption = self.process_caption(subset, image_info.caption) - captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - input_ids_list.append(self.get_input_ids(caption)) + self.replacements = {} - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) + def set_current_epoch(self, epoch): + self.current_epoch = epoch + self.shuffle_buckets() - if self.token_padding_disabled: - # padding=True means pad in the batch - example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids - else: - # batch processing seems to be good - example['input_ids'] = torch.stack(input_ids_list) + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + tag = tag.strip() + if tag: + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images + def disable_token_padding(self): + self.token_padding_disabled = True - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + def add_replacement(self, str_from, str_to): + self.replacements[str_from] = str_to - if self.debug_dataset: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example + def process_caption(self, subset: BaseSubset, caption): + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = ( + is_drop_out + or subset.caption_dropout_every_n_epochs > 0 + and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 + ) + + if is_drop_out: + caption = "" + else: + if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + fixed_tokens = [] + flex_tokens = [t.strip() for t in caption.strip().split(",")] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = flex_tokens[subset.keep_tokens :] + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + + flex_tokens = dropout_tags(flex_tokens) + + caption = ", ".join(fixed_tokens + flex_tokens) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) + + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer( + caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" + ).input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo, subset: BaseSubset): + self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset + + def make_buckets(self): + """ + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + """ + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketを作成し、画像をbucketに振り分ける + if self.enable_bucket: + if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み + self.bucket_manager = BucketManager( + self.bucket_no_upscale, + (self.width, self.height), + self.min_bucket_reso, + self.max_bucket_reso, + self.bucket_reso_steps, + ) + if not self.bucket_no_upscale: + self.bucket_manager.make_buckets() + else: + print( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" + ) + + img_ar_errors = [] + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height + ) + + # print(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) + + self.bucket_manager.sort() + else: + self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) + self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for image_info in self.image_data.values(): + for _ in range(image_info.num_repeats): + self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + + # bucket情報を表示、格納する + if self.enable_bucket: + self.bucket_info = {"buckets": {}} + print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): + count = len(bucket) + if count > 0: + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} + print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + + # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる + self.buckets_indices: List(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.bucket_manager.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + + # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す + #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる + # + # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは + # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう + # # そのためバッチサイズを画像種類までに制限する + # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? + # # TO DO 正則化画像をepochまたがりで利用する仕組み + # num_of_image_types = len(set(bucket)) + # bucket_batch_size = min(self.batch_size, num_of_image_types) + # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) + # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # for batch_index in range(batch_count): + # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + # ↑ここまで + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + self.bucket_manager.shuffle() + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): + image_height, image_width = image.shape[0:2] + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサイズする + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + image_height, image_width = image.shape[0:2] + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p : p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p : p + reso[1]] + + assert ( + image.shape[0] == reso[1] and image.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def is_latent_cacheable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + + def cache_latents(self, vae): + # TODO ここを高速化したい + print("caching latents.") + for info in tqdm(self.image_data.values()): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if subset.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, subset: BaseSubset, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if subset.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split("_") + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + 0.5) + nw = int(width * scale + 0.5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + 0.5) + face_cy = int(face_cy * scale + 0.5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if subset.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1 : p1 + target_size, :] + else: + image = image[:, p1 : p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None + return np.load(npz_file)["arr_0"] + + def __len__(self): + return self._length + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)["image"] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example["loss_weights"] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example["input_ids"] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + example["captions"] = captions + return example class DreamBoothDataset(BaseDataset): - def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__( + self, + subsets: Sequence[DreamBoothSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + prior_loss_weight: float, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.latents_cache = None + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None - self.enable_bucket = enable_bucket - if self.enable_bucket: - assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None # この情報は使われない - self.bucket_no_upscale = False - - def read_caption(img_path, caption_extension): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(subset: DreamBoothSubset): - if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") - return [], [] - - img_paths = glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: - print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") - captions.append("") + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale else: - captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は使われない + self.bucket_no_upscale = False - self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + def read_caption(img_path, caption_extension): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - return img_paths, captions + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption - print("prepare images.") - num_train_images = 0 - num_reg_images = 0 - reg_infos: List[ImageInfo] = [] - for subset in subsets: - if subset.num_repeats < 1: - print( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") - continue + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] - if subset in self.subsets: - print( - f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") - continue + img_paths = glob_images(subset.image_dir, "*") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - img_paths, captions = load_dreambooth_dir(subset) - if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") - continue + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None and subset.class_tokens is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + captions.append("") + else: + captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) - if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) - else: - num_train_images += subset.num_repeats * len(img_paths) + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) - if subset.is_reg: - reg_infos.append(info) + return img_paths, captions + + print("prepare images.") + num_train_images = 0 + num_reg_images = 0 + reg_infos: List[ImageInfo] = [] + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue + + img_paths, captions = load_dreambooth_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + continue + + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if subset.is_reg: + reg_infos.append(info) + else: + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") else: - self.register_image(info, subset) + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: + break + first_loop = False - subset.img_count = len(img_paths) - self.subsets.append(subset) - - print(f"{num_train_images} train images with repeating.") - self.num_train_images = num_train_images - - print(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats - else: - info.num_repeats += 1 - n += 1 - if n >= num_train_images: - break - first_loop = False - - self.num_reg_images = num_reg_images + self.num_reg_images = num_reg_images class FineTuningDataset(BaseDataset): - def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__( + self, + subsets: Sequence[FineTuningSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - self.batch_size = batch_size + self.batch_size = batch_size - self.num_train_images = 0 - self.num_reg_images = 0 + self.num_train_images = 0 + self.num_reg_images = 0 - for subset in subsets: - if subset.num_repeats < 1: - print( - f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") - continue + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue - if subset in self.subsets: - print( - f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") - continue + if subset in self.subsets: + print( + f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue - # メタデータを読み込む - if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") - with open(subset.metadata_file, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") + # メタデータを読み込む + if os.path.exists(subset.metadata_file): + print(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") - if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") - continue + if len(metadata) < 1: + print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + continue - tags_list = [] - for image_key, img_md in metadata.items(): - # path情報を作る - if os.path.exists(image_key): - abs_path = image_key + tags_list = [] + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] + + caption = img_md.get("caption") + tags = img_md.get("tags") + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ", " + tags + tags_list.append(tags) + + if caption is None: + caption = "" + + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get("train_resolution") + + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + + self.register_image(image_info, subset) + + self.num_train_images += len(metadata) * subset.num_repeats + + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) + + # check existence of all npz files + use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) + if use_npz_latents: + flip_aug_in_subset = False + npz_any = False + npz_all = True + + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] + + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if subset.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + use_npz_latents = False + print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + elif not npz_all: + use_npz_latents = False + print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + if flip_aug_in_subset: + print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + # else: + # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) + + if sizes is None: + if use_npz_latents: + use_npz_latents = False + print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + + assert ( + resolution is not None + ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path - else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + if not enable_bucket: + print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + print("using bucket info in metadata / メタデータ内のbucket情報を使います") + self.enable_bucket = True - caption = img_md.get('caption') - tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - tags_list.append(tags) + assert ( + not bucket_no_upscale + ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - if caption is None: - caption = "" + # bucket情報を初期化しておく、make_bucketsで再作成しない + self.bucket_manager = BucketManager(False, None, None, None, None) + self.bucket_manager.set_predefined_resos(resos) - image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) - image_info.image_size = img_md.get('train_resolution') + # npz情報をきれいにしておく + if not use_npz_latents: + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None - if not subset.color_aug and not subset.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + ".npz" - self.register_image(image_info, subset) + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + "_flip.npz" + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip - self.num_train_images += len(metadata) * subset.num_repeats + # image_key is relative path + npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") + npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") - # TODO do not record tag freq when no tag - self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) - subset.img_count = len(metadata) - self.subsets.append(subset) + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None - # check existence of all npz files - use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) - if use_npz_latents: - flip_aug_in_subset = False - npz_any = False - npz_all = True - - for image_info in self.image_data.values(): - subset = self.image_to_subset[image_info.image_key] - - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz - - if subset.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - flip_aug_in_subset = True - npz_all = npz_all and has_npz - - if npz_any and not npz_all: - break - - if not npz_any: - use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") - elif not npz_all: - use_npz_latents = False - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") - if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") - # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") - - # check min/max bucket size - sizes = set() - resos = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - resos.add(tuple(image_info.image_size)) - - if sizes is None: - if use_npz_latents: - use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") - - assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" - - self.enable_bucket = enable_bucket - if self.enable_bucket: - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") - self.enable_bucket = True - - assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - - # bucket情報を初期化しておく、make_bucketsで再作成しない - self.bucket_manager = BucketManager(False, None, None, None, None) - self.bucket_manager.set_predefined_resos(resos) - - # npz情報をきれいにしておく - if not use_npz_latents: - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None - - def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + '.npz' - - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + '_flip.npz' - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip - - # image_key is relative path - npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz') - npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz') - - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None - - return npz_file_norm, npz_file_flip + return npz_file_norm, npz_file_flip # behave as Dataset mock class DatasetGroup(torch.utils.data.ConcatDataset): - def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): - self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] - super().__init__(datasets) + super().__init__(datasets) - self.image_data = {} - self.num_train_images = 0 - self.num_reg_images = 0 + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 - # simply concat together - # TODO: handling image_data key duplication among dataset - # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. - for dataset in datasets: - self.image_data.update(dataset.image_data) - self.num_train_images += dataset.num_train_images - self.num_reg_images += dataset.num_reg_images + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. + for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images - def add_replacement(self, str_from, str_to): - for dataset in self.datasets: - dataset.add_replacement(str_from, str_to) + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) - # def make_buckets(self): - # for dataset in self.datasets: - # dataset.make_buckets() + # def make_buckets(self): + # for dataset in self.datasets: + # dataset.make_buckets() - def cache_latents(self, vae): - for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") - dataset.cache_latents(vae) + def cache_latents(self, vae): + for i, dataset in enumerate(self.datasets): + print(f"[Dataset {i}]") + dataset.cache_latents(vae) - def is_latent_cacheable(self) -> bool: - return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + def is_latent_cacheable(self) -> bool: + return all([dataset.is_latent_cacheable() for dataset in self.datasets]) - def set_current_epoch(self, epoch): - for dataset in self.datasets: - dataset.set_current_epoch(epoch) + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) - def disable_token_padding(self): - for dataset in self.datasets: - dataset.disable_token_padding() + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") - train_dataset.set_current_epoch(1) - k = 0 - indices = list(range(len(train_dataset))) - random.shuffle(indices) - for i, idx in enumerate(indices): - example = train_dataset[idx] - if example['latents'] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') - if show_input_ids: - print(f"input ids: {iid}") - if example['images'] is not None: - im = example['images'][j] - print(f"image size: {im.size()}") - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - if os.name == 'nt': # only windows - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27 or (example['images'] is None and i >= 8): - break + train_dataset.set_current_epoch(1) + k = 0 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + for i, idx in enumerate(indices): + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or (example["images"] is None and i >= 8): + break def glob_images(directory, base="*"): - img_paths = [] - for ext in IMAGE_EXTENSIONS: - if base == '*': - img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) - else: - img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - # img_paths = list(set(img_paths)) # 重複を排除 - # img_paths.sort() - return img_paths + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + # img_paths = list(set(img_paths)) # 重複を排除 + # img_paths.sort() + return img_paths def glob_images_pathlib(dir_path, recursive): - image_paths = [] - if recursive: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.rglob('*' + ext)) - else: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.glob('*' + ext)) - # image_paths = list(set(image_paths)) # 重複を排除 - # image_paths.sort() - return image_paths + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob("*" + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob("*" + ext)) + # image_paths = list(set(image_paths)) # 重複を排除 + # image_paths.sort() + return image_paths + # endregion @@ -1165,86 +1330,86 @@ EPSILON = 1e-6 def exists(val): - return val is not None + return val is not None def default(val, d): - return val if exists(val) else d + return val if exists(val) else d def model_hash(filename): - """Old model hash used by stable-diffusion-webui""" - try: - with open(filename, "rb") as file: - m = hashlib.sha256() + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() - file.seek(0x100000) - m.update(file.read(0x10000)) - return m.hexdigest()[0:8] - except FileNotFoundError: - return 'NOFILE' + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return "NOFILE" def calculate_sha256(filename): - """New model hash used by stable-diffusion-webui""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 + """New model hash used by stable-diffusion-webui""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(blksize), b""): - hash_sha256.update(chunk) + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() def precalculate_safetensors_hashes(tensors, metadata): - """Precalculate the model hashes needed by sd-webui-additional-networks to - save time on indexing the model later.""" + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" - # Because writing user metadata to the file can change the result of - # sd_models.model_hash(), only retain the training metadata for purposes of - # calculating the hash, as they are meant to be immutable - metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} - bytes = safetensors.torch.save(tensors, metadata) - b = BytesIO(bytes) + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) - model_hash = addnet_hash_safetensors(b) - legacy_hash = addnet_hash_legacy(b) - return model_hash, legacy_hash + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash def addnet_hash_legacy(b): - """Old model hash used by sd-webui-additional-networks for .safetensors format files""" - m = hashlib.sha256() + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() - b.seek(0x100000) - m.update(b.read(0x10000)) - return m.hexdigest()[0:8] + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] def addnet_hash_safetensors(b): - """New model hash used by sd-webui-additional-networks for .safetensors format files""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - b.seek(0) - header = b.read(8) - n = int.from_bytes(header, "little") + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") - offset = n + 8 - b.seek(offset) - for chunk in iter(lambda: b.read(blksize), b""): - hash_sha256.update(chunk) + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() def get_git_revision_hash() -> str: - try: - return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip() - except: - return "(unknown)" + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() + except: + return "(unknown)" # flash attention forwards and backwards @@ -1253,424 +1418,603 @@ def get_git_revision_hash() -> str: class FlashAttentionFunction(torch.autograd.function.Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - scale = (q.shape[-1] ** -0.5) + scale = q.shape[-1] ** -0.5 - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - return o + return o - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors - device = q.device + device = q.device - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights - mc) + exp_attn_weights = torch.exp(attn_weights - mc) - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) - p = exp_attn_weights / lc + p = exp_attn_weights / lc - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention (not xformers)") - flash_func = FlashAttentionFunction + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + h = self.heads + q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) + context = context if context is not None else x + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + diffusers.models.attention.CrossAttention.forward = forward_flash_attn def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + context = default(context, x) + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, 'b n h d -> b n (h d)', h=h) + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion # region arguments + def add_sd_models_arguments(parser: argparse.ArgumentParser): - # for pretrained models - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - parser.add_argument("--tokenizer_cache_dir", type=str, default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)") + # for pretrained models + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor") + parser.add_argument( + "--optimizer_type", + type=str, + default="", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + ) - # backward compatibility - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--use_lion_optimizer", action="store_true", - help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)") + # backward compatibility + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)", + ) + parser.add_argument( + "--use_lion_optimizer", + action="store_true", + help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)", + ) - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_grad_norm", default=1.0, type=float, - help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない") + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない" + ) - parser.add_argument("--optimizer_args", type=str, default=None, nargs='*', - help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")") + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', + ) - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") - parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1, - help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数") - parser.add_argument("--lr_scheduler_power", type=float, default=1, - help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power") + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--output_name", type=str, default=None, - help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_n_epoch_ratio", type=int, default=None, - help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)") - parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_state", type=int, default=None, - help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving / 保存時に精度を変更して保存する", + ) + parser.add_argument( + "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" + ) + parser.add_argument( + "--save_n_epoch_ratio", + type=int, + default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)", + ) + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") - parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument( + "--max_token_length", + type=int, + default=None, + choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)", + ) + parser.add_argument( + "--mem_eff_attn", + action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", + ) + parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--max_train_epochs", type=int, default=None, - help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") - parser.add_argument("--max_data_loader_n_workers", type=int, default=8, - help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") - parser.add_argument("--persistent_data_loader_workers", action="store_true", - help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--noise_offset", type=float, default=None, - help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") - parser.add_argument("--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", + ) + parser.add_argument( + "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + ) + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)", + ) + parser.add_argument( + "--logging_dir", + type=str, + default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", + ) + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument( + "--noise_offset", + type=float, + default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", + ) + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) - parser.add_argument("--sample_every_n_steps", type=int, default=None, - help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する") - parser.add_argument("--sample_every_n_epochs", type=int, default=None, - help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)") - parser.add_argument("--sample_prompts", type=str, default=None, - help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル") - parser.add_argument('--sample_sampler', type=str, default='ddim', - choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', - 'dpmsolver++', 'dpmsingle', - 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], - help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類') - - parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") + parser.add_argument( + "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", + ) + parser.add_argument( + "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル" + ) + parser.add_argument( + "--sample_sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", + ) - if support_dreambooth: - # DreamBooth training - parser.add_argument("--prior_loss_weight", type=float, default=1.0, - help="loss weight for regularization images / 正則化画像のlossの重み") + parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") + + if support_dreambooth: + # DreamBooth training + parser.add_argument( + "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" + ) def verify_training_args(args: argparse.Namespace): - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") -def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool): - # dataset common - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") - parser.add_argument("--keep_tokens", type=int, default=0, - help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") - parser.add_argument("--bucket_reso_steps", type=int, default=64, - help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") - parser.add_argument("--bucket_no_upscale", action="store_true", - help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") +def add_dataset_arguments( + parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool +): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする" + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)", + ) + parser.add_argument( + "--keep_tokens", + type=int, + default=0, + help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", + ) + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument( + "--face_crop_aug_range", + type=str, + default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)", + ) + parser.add_argument( + "--random_crop", + action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", + ) + parser.add_argument( + "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)" + ) + parser.add_argument( + "--resolution", + type=str, + default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)", + ) + parser.add_argument( + "--cache_latents", + action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)", + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" + ) + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", + ) + parser.add_argument( + "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + ) - if support_caption_dropout: - # Textual Inversion はcaptionのdropoutをsupportしない - # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに - parser.add_argument("--caption_dropout_rate", type=float, default=0.0, - help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") - parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0, - help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") - parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0, - help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument( + "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合" + ) + parser.add_argument( + "--caption_dropout_every_n_epochs", + type=int, + default=0, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする", + ) + parser.add_argument( + "--caption_tag_dropout_rate", + type=float, + default=0.0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合", + ) - if support_dreambooth: - # DreamBooth dataset - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - if support_caption: - # caption dataset - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") - parser.add_argument("--dataset_repeats", type=int, default=1, - help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument( + "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数" + ) def add_sd_saving_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") + parser.add_argument( + "--save_model_as", + type=str, + default=None, + choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)", + ) + # endregion @@ -1678,178 +2022,190 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" - optimizer_type = args.optimizer_type - if args.use_8bit_adam: - assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" - assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" - optimizer_type = "AdamW8bit" + optimizer_type = args.optimizer_type + if args.use_8bit_adam: + assert ( + not args.use_lion_optimizer + ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "AdamW8bit" - elif args.use_lion_optimizer: - assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" - optimizer_type = "Lion" + elif args.use_lion_optimizer: + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "Lion" - if optimizer_type is None or optimizer_type == "": - optimizer_type = "AdamW" - optimizer_type = optimizer_type.lower() + if optimizer_type is None or optimizer_type == "": + optimizer_type = "AdamW" + optimizer_type = optimizer_type.lower() - # 引数を分解する:boolとfloat、tupleのみ対応 - optimizer_kwargs = {} - if args.optimizer_args is not None and len(args.optimizer_args) > 0: - for arg in args.optimizer_args: - key, value = arg.split('=') + # 引数を分解する:boolとfloat、tupleのみ対応 + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") - value = value.split(",") - for i in range(len(value)): - if value[i].lower() == "true" or value[i].lower() == "false": - value[i] = (value[i].lower() == "true") + value = value.split(",") + for i in range(len(value)): + if value[i].lower() == "true" or value[i].lower() == "false": + value[i] = value[i].lower() == "true" + else: + value[i] = float(value[i]) + if len(value) == 1: + value = value[0] + else: + value = tuple(value) + + optimizer_kwargs[key] = value + # print("optkwargs:", optimizer_kwargs) + + lr = args.learning_rate + + if optimizer_type == "AdamW8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print(f"use Lion optimizer | {optimizer_kwargs}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov".lower(): + print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "DAdaptation".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + ) + print("recommend option: lr=1.0 / 推奨は1.0です") + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + ) + + optimizer_class = dadaptation.DAdaptAdam + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # 引数を確認して適宜補正する + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + optimizer_kwargs["relative_step"] = True + print(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + print(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + args.learning_rate = None + + # trainable_paramsがgroupだった時の処理:lrを削除する + if type(trainable_params) == list and type(trainable_params[0]) == dict: + has_group_lr = False + for group in trainable_params: + p = group.pop("lr", None) + has_group_lr = has_group_lr or (p is not None) + + if has_group_lr: + # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない + print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + args.unet_lr = None + args.text_encoder_lr = None + + if args.lr_scheduler != "adafactor": + print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど + + lr = None else: - value[i] = float(value[i]) - if len(value) == 1: - value = value[0] - else: - value = tuple(value) + if args.max_grad_norm != 0.0: + print( + f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" + ) + if args.lr_scheduler != "constant_with_warmup": + print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") - optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - lr = args.learning_rate + elif optimizer_type == "AdamW".lower(): + print(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - if optimizer_type == "AdamW8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.AdamW8bit - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します") - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = bnb.optim.SGD8bit - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "Lion".lower(): - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") - optimizer_class = lion_pytorch.Lion - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = torch.optim.SGD - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "DAdaptation".lower(): - try: - import dadaptation - except ImportError: - raise ImportError("No dadaptation / dadaptation がインストールされていないようです") - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - - actual_lr = lr - lr_count = 1 - if type(trainable_params) == list and type(trainable_params[0]) == dict: - lrs = set() - actual_lr = trainable_params[0].get("lr", actual_lr) - for group in trainable_params: - lrs.add(group.get("lr", actual_lr)) - lr_count = len(lrs) - - if actual_lr <= 0.1: - print( - f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}') - print('recommend option: lr=1.0 / 推奨は1.0です') - if lr_count > 1: - print( - f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}") - - optimizer_class = dadaptation.DAdaptAdam - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "Adafactor".lower(): - # 引数を確認して適宜補正する - if "relative_step" not in optimizer_kwargs: - optimizer_kwargs["relative_step"] = True # default - if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") - optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") - - if optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") - if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") - args.learning_rate = None - - # trainable_paramsがgroupだった時の処理:lrを削除する - if type(trainable_params) == list and type(trainable_params[0]) == dict: - has_group_lr = False - for group in trainable_params: - p = group.pop("lr", None) - has_group_lr = has_group_lr or (p is not None) - - if has_group_lr: - # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") - args.unet_lr = None - args.text_encoder_lr = None - - if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") - args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど - - lr = None else: - if args.max_grad_norm != 0.0: - print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません") - if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") - if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + # 任意のoptimizerを使う + optimizer_type = args.optimizer_type # lowerでないやつ(微妙) + print(f"use {optimizer_type} | {optimizer_kwargs}") + if "." not in optimizer_type: + optimizer_module = torch.optim + else: + values = optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + optimizer_type = values[-1] - optimizer_class = transformers.optimization.Adafactor - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") - optimizer_class = torch.optim.AdamW - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - else: - # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split(".") - optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] - - optimizer_class = getattr(optimizer_module, optimizer_type) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ - optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - - return optimizer_name, optimizer_args, optimizer + return optimizer_name, optimizer_args, optimizer # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler @@ -1866,532 +2222,589 @@ def get_scheduler_fix( num_cycles: int = 1, power: float = 1.0, ): - """ - Unified API to get any scheduler from its name. - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_cycles (`int`, *optional*): - The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. - power (`float`, *optional*, defaults to 1.0): - Power factor. See `POLYNOMIAL` scheduler - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - """ - if name.startswith("adafactor"): - assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" - initial_lr = float(name.split(':')[1]) - # print("adafactor scheduler init lr", initial_lr) - return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) + """ + Unified API to get any scheduler from its name. + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_cycles (`int`, *optional*): + The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. + power (`float`, *optional*, defaults to 1.0): + Power factor. See `POLYNOMIAL` scheduler + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + """ + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" + initial_lr = float(name.split(":")[1]) + # print("adafactor scheduler init lr", initial_lr) + return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - if name == SchedulerType.COSINE_WITH_RESTARTS: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles - ) + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + ) - if name == SchedulerType.POLYNOMIAL: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power - ) + if name == SchedulerType.POLYNOMIAL: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): - # backward compatibility - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None - # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" - if args.resolution is not None: - args.resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(args.resolution) == 1: - args.resolution = (args.resolution[0], args.resolution[0]) - assert len(args.resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(",")]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert ( + len(args.resolution) == 2 + ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - if args.face_crop_aug_range is not None: - args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \ - f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - args.face_crop_aug_range = None + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) + assert ( + len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1] + ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None - if support_metadata: - if args.in_json is not None and (args.color_aug or args.random_crop): - print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます") + if support_metadata: + if args.in_json is not None and (args.color_aug or args.random_crop): + print( + f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" + ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH + print("prepare tokenizer") + original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_')) - if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) + if tokenizer is None: + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(original_path) - if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) - return tokenizer + return tokenizer def prepare_accelerator(args: argparse.Namespace): - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + logging_dir=logging_dir, + ) - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) - return accelerator, unwrap_model + return accelerator, unwrap_model def prepare_dtype(args: argparse.Namespace): - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 - return weight_dtype, save_dtype + return weight_dtype, save_dtype def load_target_model(args: argparse.Namespace, weight_dtype): - name_or_path = args.pretrained_model_name_or_path - name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path - load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) - else: - print("load Diffusers pretrained models") - try: - pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) - except EnvironmentError as ex: - print( - f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}") - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe + name_or_path = args.pretrained_model_name_or_path + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) + else: + print("load Diffusers pretrained models") + try: + pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) + except EnvironmentError as ex: + print( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" + ) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") - return text_encoder, vae, unet, load_stable_diffusion_format + return text_encoder, vae, unet, load_stable_diffusion_format def patch_accelerator_for_fp16_training(accelerator): - org_unscale_grads = accelerator.scaler._unscale_grads_ + org_unscale_grads = accelerator.scaler._unscale_grads_ - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): - # with no_token_padding, the length is not max length, return result immediately - if input_ids.size()[-1] != tokenizer.model_max_length: - return text_encoder(input_ids)[0] + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] - b_size = input_ids.size()[0] - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - return encoder_hidden_states + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + return encoder_hidden_states def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): - model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name - ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") - return model_name, ckpt_name + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): - saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs - if saving: - os.makedirs(args.output_dir, exist_ok=True) - save_func() + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if saving: + os.makedirs(args.output_dir, exist_ok=True) + save_func() - if args.save_last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs - remove_old_func(remove_epoch_no) - return saving + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving -def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): - epoch_no = epoch + 1 - model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) +def save_sd_model_on_epoch_end( + args: argparse.Namespace, + accelerator, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder, + unet, + vae, +): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) - if save_stable_diffusion_format: - def save_sd(): - ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_path, epoch_no, global_step, save_dtype, vae) + if save_stable_diffusion_format: - def remove_sd(old_epoch_no): - _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) - save_func = save_sd - remove_old_func = remove_sd - else: - def save_du(): - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) - print(f"saving model: {out_dir}") - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_path, vae=vae, use_safetensors=use_safetensors) + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) - def remove_du(old_epoch_no): - out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) - if os.path.exists(out_dir_old): - print(f"removing old model: {out_dir_old}") - shutil.rmtree(out_dir_old) + save_func = save_sd + remove_old_func = remove_sd + else: - save_func = save_du - remove_old_func = remove_du + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) - saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) - if saving and args.save_state: - save_state_on_epoch_end(args, accelerator, model_name, epoch_no) + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + save_state_on_epoch_end(args, accelerator, model_name, epoch_no) def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) - last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs - if last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") - shutil.rmtree(state_dir_old) + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) -def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder, + unet, + vae, +): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) - ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") - ckpt_file = os.path.join(args.output_dir, ckpt_name) + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_path, epoch, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, model_name) - os.makedirs(out_dir, exist_ok=True) + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae + ) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_path, vae=vae, use_safetensors=use_safetensors) + print(f"save trained model as Diffusers to {out_dir}") + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) def save_state_on_train_end(args: argparse.Namespace, accelerator): - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) # scheduler: SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = 'scaled_linear' +SCHEDLER_SCHEDULE = "scaled_linear" -def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None): - """ - 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない - clip skipは対応した - """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: - return - else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch - return +def sample_images( + accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None +): + """ + 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない + clip skipは対応した + """ + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return - print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") - return + print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return - org_vae_device = vae.device # CPUにいるはず - vae.to(device) + org_vae_device = vae.device # CPUにいるはず + vae.to(device) - # clip skip 対応のための wrapper を作る - if args.clip_skip is None: - text_encoder_or_wrapper = text_encoder - else: - class Wrapper(): - def __init__(self, tenc) -> None: - self.tenc = tenc - self.config = {} - super().__init__() + # clip skip 対応のための wrapper を作る + if args.clip_skip is None: + text_encoder_or_wrapper = text_encoder + else: - def __call__(self, input_ids, attention_mask): - enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states) - pooled_output = enc_out['pooler_output'] - return encoder_hidden_states, pooled_output # 1st output is only used + class Wrapper: + def __init__(self, tenc) -> None: + self.tenc = tenc + self.config = {} + super().__init__() - text_encoder_or_wrapper = Wrapper(text_encoder) + def __call__(self, input_ids, attention_mask): + enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] + encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states) + pooled_output = enc_out["pooler_output"] + return encoder_hidden_states, pooled_output # 1st output is only used - # read prompts - with open(args.sample_prompts, 'rt', encoding='utf-8') as f: - prompts = f.readlines() + text_encoder_or_wrapper = Wrapper(text_encoder) - # schedulerを用意する - sched_init_args = {} - if args.sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif args.sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms': - scheduler_cls = LMSDiscreteScheduler - elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler': - scheduler_cls = EulerDiscreteScheduler - elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a': - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args['algorithm_type'] = args.sample_sampler - elif args.sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2': - scheduler_cls = KDPM2DiscreteScheduler - elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a': - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler + # read prompts + with open(args.sample_prompts, "rt", encoding="utf-8") as f: + prompts = f.readlines() - if args.v_parameterization: - sched_init_args['prediction_type'] = 'v_prediction' + # schedulerを用意する + sched_init_args = {} + if args.sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif args.sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sample_sampler + elif args.sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif args.sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler - scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args) + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - scheduler.config.clip_sample = True + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) - pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False) - pipeline.to(device) + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + scheduler.config.clip_sample = True - save_dir = args.output_dir + "/sample" - os.makedirs(save_dir, exist_ok=True) + pipeline = StableDiffusionPipeline( + text_encoder=text_encoder_or_wrapper, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + pipeline.to(device) - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) - with torch.no_grad(): - with accelerator.autocast(): - for i, prompt in enumerate(prompts): - if not accelerator.is_main_process: - continue - prompt = prompt.strip() - if len(prompt) == 0 or prompt[0] == '#': - continue + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() - # subset of gen_img_diffusers - prompt_args = prompt.split(' --') - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - for parg in prompt_args: - try: - m = re.match(r'w (\d+)', parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue + with torch.no_grad(): + with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue + prompt = prompt.strip() + if len(prompt) == 0 or prompt[0] == "#": + continue - m = re.match(r'h (\d+)', parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue - m = re.match(r'd (\d+)', parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue - m = re.match(r's (\d+)', parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue - m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue - m = re.match(r'n (.+)', parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] - image.save(os.path.join(save_dir, img_filename)) + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) - # clear pipeline and cache to reduce vram usage - del pipeline - torch.cuda.empty_cache() + image.save(os.path.join(save_dir, img_filename)) + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + torch.set_rng_state(rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + vae.to(org_vae_device) - torch.set_rng_state(rng_state) - torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) # endregion @@ -2399,24 +2812,24 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v class ImageLoadingDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] + def __getitem__(self, idx): + img_path = self.images[idx] - try: - image = Image.open(img_path).convert("RGB") - # convert to tensor temporarily so dataloader will accept it - tensor_pil = transforms.functional.pil_to_tensor(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None - return (tensor_pil, img_path) + return (tensor_pil, img_path) # endregion diff --git a/train_db.py b/train_db.py index 5fd3c65b..2ad9c69c 100644 --- a/train_db.py +++ b/train_db.py @@ -18,368 +18,417 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, + ConfigSanitizer, + BlueprintGenerator, ) def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, False) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - user_config = { - "datasets": [{ - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) - }] - } + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - if args.no_token_padding: - train_dataset_group.disable_token_padding() + if args.no_token_padding: + train_dataset_group.disable_token_padding() - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return - if cache_latents: - assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # acceleratorを準備する - print("prepare accelerator") + # acceleratorを準備する + print("prepare accelerator") - if args.gradient_accumulation_steps > 1: - print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") - print( - f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") + if args.gradient_accumulation_steps > 1: + print( + f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" + ) + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" + ) - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # 学習を準備する:モデルを適切な状態にする - train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 - unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(train_text_encoder) - if not train_text_encoder: - print("Text Encoder is not trained.") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - if train_text_encoder: - trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) - else: - trainable_params = unet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - if args.stop_text_encoder_training is None: - args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - - # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - if not train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("dreambooth") - - loss_list = [] - loss_total = 0.0 - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - - # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - unet.train() - # train==True is required to enable gradient_checkpointing - if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: - text_encoder.train() - - for step, batch in enumerate(train_dataloader): - # 指定したステップ数でText Encoderの学習を止める - if global_step == args.stop_text_encoder_training: - print(f"stop text encoder training at step {global_step}") - if not args.gradient_checkpointing: - text_encoder.train(False) - text_encoder.requires_grad_(False) - - with accelerator.accumulate(unet): + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + # 学習を準備する:モデルを適切な状態にする + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 + unet.requires_grad_(True) # 念のため追加 + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") - # Get the text embedding for conditioning - with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + if train_text_encoder: + trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + trainable_params = unet.parameters() - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + _, _, optimizer = train_util.get_optimizer(args, trainable_params) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する + lr_scheduler = train_util.get_scheduler_fix( + args.lr_scheduler, + optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + num_cycles=args.lr_scheduler_num_cycles, + power=args.lr_scheduler_power, + ) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth") + + loss_list = [] + loss_total = 0.0 + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + unet.train() + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: + text_encoder.train() + + for step, batch in enumerate(train_dataloader): + # 指定したステップ数でText Encoderの学習を止める + if global_step == args.stop_text_encoder_training: + print(f"stop text encoder training at step {global_step}") + if not args.gradient_checkpointing: + text_encoder.train(False) + text_encoder.requires_grad_(False) + + with accelerator.accumulate(unet): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + if train_text_encoder: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False, True) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", + ) + parser.add_argument( + "--stop_text_encoder_training", + type=int, + default=None, + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", + ) + + args = parser.parse_args() + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) else: - target = noise + print(f"{config_path} not found.") - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - if train_text_encoder: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) - else: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - accelerator.log(logs, step=global_step) - - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch+1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - is_main_process = accelerator.is_main_process - if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, global_step, text_encoder, unet, vae) - print("model saved.") - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False, True) - train_util.add_training_arguments(parser, True) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - - parser.add_argument("--no_token_padding", action="store_true", - help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") - parser.add_argument("--stop_text_encoder_training", type=int, default=None, - help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") - - args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") - - train(args) \ No newline at end of file + train(args) diff --git a/train_network.py b/train_network.py index 454bd254..f78d8e47 100644 --- a/train_network.py +++ b/train_network.py @@ -26,655 +26,693 @@ from library.config_util import ( def collate_fn(examples): - return examples[0] + return examples[0] # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = {"loss/current": current_loss, "loss/average": avr_loss} + logs = {"loss/current": current_loss, "loss/average": avr_loss} - if args.network_train_unet_only: - logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) - else: - logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) - logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder + if args.network_train_unet_only: + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + else: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr'] + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - return logs + return logs def train(args): - session_id = random.randint(0, 2**32) - training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None - use_user_config = args.dataset_config is not None + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None - if args.seed is not None: - set_seed(args.seed) + if args.seed is not None: + set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) - if use_user_config: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [{ - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) - }] - } + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - print("Train with captions.") - user_config = { - "datasets": [{ - "subsets": [{ - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - }] - }] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)") - return - - if cache_latents: - assert train_dataset_group.is_latent_cacheable( - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - - # work on low-ram device - if args.lowram: - text_encoder.to("cuda") - unet.to("cuda") - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # prepare network - import sys - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - - net_kwargs = {} - if args.network_args is not None: - for net_arg in args.network_args: - key, value = net_arg.split('=') - net_kwargs[key] = value - - # if a new network is added in future, add if ~ then blocks for each network (;'∀') - network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) - if network is None: - return - - if args.network_weights is not None: - print("load network weights from:", args.network_weights) - network.load_weights(args.network_weights) - - train_unet = not args.network_train_text_encoder_only - train_text_encoder = not args.network_train_unet_only - network.apply_to(text_encoder, unet, train_text_encoder, train_unet) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - network.enable_gradient_checkpointing() # may have no effect - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) - if is_main_process: - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - network.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_unet and train_text_encoder: - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler) - elif train_unet: - unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler) - elif train_text_encoder: - text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, network, optimizer, train_dataloader, lr_scheduler) - else: - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - network, optimizer, train_dataloader, lr_scheduler) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - text_encoder.train() - - # set top parameter requires_grad = True for gradient checkpointing works - if type(text_encoder) == DDP: - text_encoder.module.text_model.embeddings.requires_grad_(True) - else: - text_encoder.text_model.embeddings.requires_grad_(True) - else: - unet.eval() - text_encoder.eval() - - # support DistributedDataParallel - if type(text_encoder) == DDP: - text_encoder = text_encoder.module - unet = unet.module - network = network.module - - network.prepare_grad_etc(text_encoder, unet) - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - if is_main_process: - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - # TODO refactor metadata creation and move to util - metadata = { - "ss_session_id": session_id, # random integer indicating which group of epochs the model came from - "ss_training_started_at": training_started_at, # unix timestamp - "ss_output_name": args.output_name, - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset_group.num_train_images, - "ss_num_reg_images": train_dataset_group.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_gradient_checkpointing": args.gradient_checkpointing, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim - "ss_network_alpha": args.network_alpha, # some networks may not use this value - "ss_mixed_precision": args.mixed_precision, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_cache_latents": bool(args.cache_latents), - "ss_seed": args.seed, - "ss_lowram": args.lowram, - "ss_noise_offset": args.noise_offset, - "ss_training_comment": args.training_comment, # will not be updated after training - "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), - "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), - "ss_max_grad_norm": args.max_grad_norm, - "ss_caption_dropout_rate": args.caption_dropout_rate, - "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, - "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, - "ss_face_crop_aug_range": args.face_crop_aug_range, - "ss_prior_loss_weight": args.prior_loss_weight, - } - - if use_user_config: - # save metadata of multiple datasets - # NOTE: pack "ss_datasets" value as json one time - # or should also pack nested collections as json? - datasets_metadata = [] - tag_frequency = {} # merge tag frequency for metadata editor - dataset_dirs_info = {} # merge subset dirs for metadata editor - - for dataset in train_dataset_group.datasets: - is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) - dataset_metadata = { - "is_dreambooth": is_dreambooth_dataset, - "batch_size_per_device": dataset.batch_size, - "num_train_images": dataset.num_train_images, # includes repeating - "num_reg_images": dataset.num_reg_images, - "resolution": (dataset.width, dataset.height), - "enable_bucket": bool(dataset.enable_bucket), - "min_bucket_reso": dataset.min_bucket_reso, - "max_bucket_reso": dataset.max_bucket_reso, - "tag_frequency": dataset.tag_frequency, - "bucket_info": dataset.bucket_info, - } - - subsets_metadata = [] - for subset in dataset.subsets: - subset_metadata = { - "img_count": subset.img_count, - "num_repeats": subset.num_repeats, - "color_aug": bool(subset.color_aug), - "flip_aug": bool(subset.flip_aug), - "random_crop": bool(subset.random_crop), - "shuffle_caption": bool(subset.shuffle_caption), - "keep_tokens": subset.keep_tokens, - } - - image_dir_or_metadata_file = None - if subset.image_dir: - image_dir = os.path.basename(subset.image_dir) - subset_metadata["image_dir"] = image_dir - image_dir_or_metadata_file = image_dir - - if is_dreambooth_dataset: - subset_metadata["class_tokens"] = subset.class_tokens - subset_metadata["is_reg"] = subset.is_reg - if subset.is_reg: - image_dir_or_metadata_file = None # not merging reg dataset + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } else: - metadata_file = os.path.basename(subset.metadata_file) - subset_metadata["metadata_file"] = metadata_file - image_dir_or_metadata_file = metadata_file # may overwrite + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - subsets_metadata.append(subset_metadata) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - # merge dataset dir: not reg subset only - # TODO update additional-network extension to show detailed dataset config from metadata - if image_dir_or_metadata_file is not None: - # datasets may have a certain dir multiple times - v = image_dir_or_metadata_file - i = 2 - while v in dataset_dirs_info: - v = image_dir_or_metadata_file + f" ({i})" - i += 1 - image_dir_or_metadata_file = v + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return - dataset_dirs_info[image_dir_or_metadata_file] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count - } + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - dataset_metadata["subsets"] = subsets_metadata - datasets_metadata.append(dataset_metadata) + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process - # merge tag frequency: - for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): - # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える - # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない - # なので、ここで複数datasetの回数を合算してもあまり意味はない - if ds_dir_name in tag_frequency: - continue - tag_frequency[ds_dir_name] = ds_freq_for_dir + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - metadata["ss_datasets"] = json.dumps(datasets_metadata) - metadata["ss_tag_frequency"] = json.dumps(tag_frequency) - metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) - else: - # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir - assert len( - train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - dataset = train_dataset_group.datasets[0] + # work on low-ram device + if args.lowram: + text_encoder.to("cuda") + unet.to("cuda") - dataset_dirs_info = {} - reg_dataset_dirs_info = {} - if use_dreambooth_method: - for subset in dataset.subsets: - info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info - info[os.path.basename(subset.image_dir)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count - } - else: - for subset in dataset.subsets: - dataset_dirs_info[os.path.basename(subset.metadata_file)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count - } + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - metadata.update({ - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_resolution": args.resolution, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_enable_bucket": bool(dataset.enable_bucket), - "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), - "ss_min_bucket_reso": dataset.min_bucket_reso, - "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_keep_tokens": args.keep_tokens, - "ss_dataset_dirs": json.dumps(dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(dataset.tag_frequency), - "ss_bucket_info": json.dumps(dataset.bucket_info), - }) - - # add extra args - if args.network_args: - metadata["ss_network_args"] = json.dumps(net_kwargs) - # for key, value in net_kwargs.items(): - # metadata["ss_arg_" + key] = value - - # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path - if os.path.exists(sd_model_name): - metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) - metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) - sd_model_name = os.path.basename(sd_model_name) - metadata["ss_sd_model_name"] = sd_model_name - - if args.vae is not None: - vae_name = args.vae - if os.path.exists(vae_name): - metadata["ss_vae_hash"] = train_util.model_hash(vae_name) - metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) - vae_name = os.path.basename(vae_name) - metadata["ss_vae_name"] = vae_name - - metadata = {k: str(v) for k, v in metadata.items()} - - # make minimum metadata for filtering - minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] - minimum_metadata = {} - for key in minimum_keys: - if key in metadata: - minimum_metadata[key] = metadata[key] - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("network_train") - - loss_list = [] - loss_total = 0.0 - for epoch in range(num_train_epochs): - if is_main_process: - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - - metadata["ss_epoch"] = str(epoch+1) - - network.on_epoch_start(text_encoder, unet) - - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(network): + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - with torch.set_grad_enabled(train_text_encoder): - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + # prepare network + import sys - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + if network is None: + return - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + if args.network_weights is not None: + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) + if is_main_process: + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args.lr_scheduler, + optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps, + num_cycles=args.lr_scheduler_num_cycles, + power=args.lr_scheduler_power, + ) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_unet and train_text_encoder: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_text_encoder: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + + # set top parameter requires_grad = True for gradient checkpointing works + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) else: - target = noise + text_encoder.text_model.embeddings.requires_grad_(True) + else: + unet.eval() + text_encoder.eval() - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + network.prepare_grad_etc(text_encoder, unet) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = network.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - current_loss = loss.detach().item() - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + if is_main_process: + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not use this value + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + } - if global_step >= args.max_train_steps: - break + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch+1) + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } - accelerator.wait_for_everyone() + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } - if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える + # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない + # なので、ここで複数datasetの回数を合算してもあまり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] + minimum_metadata = {} + for key in minimum_keys: + if key in metadata: + minimum_metadata[key] = metadata[key] + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("network_train") + + loss_list = [] + loss_total = 0.0 + for epoch in range(num_train_epochs): + if is_main_process: + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + metadata["ss_epoch"] = str(epoch + 1) + + network.on_epoch_start(text_encoder, unet) + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + metadata["ss_training_finished_at"] = str(time.time()) + print(f"saving checkpoint: {ckpt_file}") + unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + if is_main_process: + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) + + if is_main_process: + network = unwrap_model(network) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) - metadata["ss_training_finished_at"] = str(time.time()) - print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - if is_main_process: - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # end of epoch - - metadata["ss_epoch"] = str(num_train_epochs) - metadata["ss_training_finished_at"] = str(time.time()) - - if is_main_process: - network = unwrap_model(network) - - accelerator.end_training() - - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + '.' + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - print("model saved.") + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + print("model saved.") -if __name__ == '__main__': - parser = argparse.ArgumentParser() +if __name__ == "__main__": + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, True) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) - parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") - parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)") + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) - parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, - help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') - parser.add_argument("--network_dim", type=int, default=None, - help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') - parser.add_argument("--network_alpha", type=float, default=1, - help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)') - parser.add_argument("--network_args", type=str, default=None, nargs='*', - help='additional argmuments for network (key=value) / ネットワークへの追加の引数') - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") - parser.add_argument("--network_train_text_encoder_only", action="store_true", - help="only training Text Encoder part / Text Encoder関連部分のみ学習する") - parser.add_argument("--training_comment", type=str, default=None, - help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") + parser.add_argument( + "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", + ) + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") + parser.add_argument( + "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + ) + parser.add_argument( + "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + ) - args = parser.parse_args() + args = parser.parse_args() - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") - - train(args) \ No newline at end of file + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + + train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 7cfaedfe..0e9fba76 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -14,8 +14,8 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, + ConfigSanitizer, + BlueprintGenerator, ) imagenet_templates_small = [ @@ -72,476 +72,525 @@ imagenet_style_templates_small = [ def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - if args.output_name is None: - args.output_name = args.token_string - use_template = args.use_object_template or args.use_style_template + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) + if args.seed is not None: + set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - # Convert the init_word to token_id - if args.init_word is not None: - init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) - if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( - f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}") - else: - init_token_ids = None - - # add new word to tokenizer, count is num_vectors_per_token - token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] - num_added_tokens = tokenizer.add_tokens(token_strings) - assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids): - token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - - # load weights - if args.weights is not None: - embeddings = load_weights(args.weights) - assert len(token_ids) == len( - embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) - for token_id, embedding in zip(token_ids, embeddings): - token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") - - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [{ - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) - }] - } + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" + ) else: - print("Train with captions.") - user_config = { - "datasets": [{ - "subsets": [{ - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - }] - }] - } + init_token_ids = None - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - print("use template for training captions. is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids): + token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - prompt_replacement = None - else: - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - trainable_params = text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - - # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler) - - index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - else: - unet.eval() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - text_encoder.to(weight_dtype) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion") - - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - - text_encoder.train() - - loss_total = 0 - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - # weight_dtype) use float instead of fp16/bf16 because text encoder is float - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } else: - target = noise + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # Let's make sure we don't update any embedding weights besides the newly added token + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates] + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, - vae, tokenizer, text_encoder, unet, prompt_replacement) + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - accelerator.log(logs, step=global_step) + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - loss_total += current_loss - avr_loss = loss_total / (step+1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - if global_step >= args.max_train_steps: - break + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args.lr_scheduler, + optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_scheduler_num_cycles, + power=args.lr_scheduler_power, + ) - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch+1) + # acceleratorがなんかよろしくやってくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) - accelerator.wait_for_everyone() + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + text_encoder.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") + + print(f"save trained model to {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) - - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, - vae, tokenizer, text_encoder, unet, prompt_replacement) - - # end of epoch - - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + '.' + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - save_weights(ckpt_file, updated_embs, save_dtype) - print("model saved.") + print("model saved.") def save_weights(file, updated_embs, save_dtype): - state_dict = {"emb_params": updated_embs} + state_dict = {"emb_params": updated_embs} - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import save_file - save_file(state_dict, file) - else: - torch.save(state_dict, file) # can be loaded in Web UI + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI def load_weights(file): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file - data = load_file(file) - else: - # compatible to Web UI's file format - data = torch.load(file, map_location='cpu') - if type(data) != dict: - raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file - if 'string_to_param' in data: # textual inversion embeddings - data = data['string_to_param'] - if hasattr(data, '_parameters'): # support old PyTorch? - data = getattr(data, '_parameters') + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - emb = next(iter(data.values())) - if type(emb) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? + data = getattr(data, "_parameters") - if len(emb.size()) == 1: - emb = emb.unsqueeze(0) + emb = next(iter(data.values())) + if type(emb) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") - return emb + if len(emb.size()) == 1: + emb = emb.unsqueeze(0) + + return emb -if __name__ == '__main__': - parser = argparse.ArgumentParser() +if __name__ == "__main__": + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, False) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", + ) - parser.add_argument("--weights", type=str, default=None, - help="embedding weights to initialize / 学習するネットワークの初期重み") - parser.add_argument("--num_vectors_per_token", type=int, default=1, - help='number of vectors per token / トークンに割り当てるembeddingsの要素数') - parser.add_argument("--token_string", type=str, default=None, - help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること") - parser.add_argument("--init_word", type=str, default=None, - help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") - parser.add_argument("--use_object_template", action='store_true', - help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する") - parser.add_argument("--use_style_template", action='store_true', - help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する") + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", + ) - args = parser.parse_args() + args = parser.parse_args() - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") - - train(args) + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + + train(args) From 83e102c6912954b6cd76ca7a82d5afb7fb3043bf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 Mar 2023 10:11:11 +0900 Subject: [PATCH 12/12] refactor config parse, feature to output config --- fine_tune.py | 20 +---------- library/train_util.py | 72 +++++++++++++++++++++++++++++++++++++- train_db.py | 20 +---------- train_network.py | 20 +---------- train_textual_inversion.py | 20 +---------- 5 files changed, 75 insertions(+), 77 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index e3cf247e..2b5255dc 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -408,24 +408,6 @@ if __name__ == "__main__": parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args) diff --git a/library/train_util.py b/library/train_util.py index 230985ef..9f541b6c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import importlib import json +import pathlib import re import shutil import time @@ -23,6 +24,7 @@ import random import hashlib import subprocess from io import BytesIO +import toml from tqdm import tqdm import torch @@ -1889,7 +1891,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", ) - parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", + ) + parser.add_argument( + "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" + ) if support_dreambooth: # DreamBooth training @@ -2016,6 +2026,66 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): ) +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if args.output_config: + # check if config file exists + if os.path.exists(config_path): + print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + exit(1) + + # convert args to dictionary + args_dict = vars(args) + + # remove unnecessary keys + for key in ["config_file", "output_config"]: + if key in args_dict: + del args_dict[key] + + # convert Path to str in dictionary + for key, value in args_dict.items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + + # convert to toml and output to file + with open(config_path, "w") as f: + toml.dump(args_dict, f) + + print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + exit(0) + + if not os.path.exists(config_path): + print(f"{config_path} not found.") + exit(1) + + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + print(args.config_file) + + return args + + # endregion # region utils diff --git a/train_db.py b/train_db.py index 2ad9c69c..c812bbc7 100644 --- a/train_db.py +++ b/train_db.py @@ -411,24 +411,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args) diff --git a/train_network.py b/train_network.py index f78d8e47..ca0da112 100644 --- a/train_network.py +++ b/train_network.py @@ -695,24 +695,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e9fba76..f591dea1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -573,24 +573,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - - if args.config_file: - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - if os.path.exists(config_path): - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = args.config_file.split(".")[0] - print(args.config_file) - else: - print(f"{config_path} not found.") + args = train_util.read_config_from_file(args, parser) train(args)