From 5a1a14f9fc096ad5c2c7cdfa15d0e23320f69ab1 Mon Sep 17 00:00:00 2001 From: TingTingin <36141041+TingTingin@users.noreply.github.com> Date: Tue, 23 May 2023 01:57:35 -0400 Subject: [PATCH 01/11] Update train_util.py Added feature to add "." if missing in caption_extension Added warning on training without captions --- library/train_util.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 41afc13b..05ec7f84 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset): self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension def __eq__(self, other) -> bool: if not isinstance(other, DreamBoothSubset): @@ -1069,7 +1071,7 @@ class DreamBoothDataset(BaseDataset): assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() break - return caption + return caption def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): @@ -1081,16 +1083,33 @@ class DreamBoothDataset(BaseDataset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] + missing_captions = [] for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: print(f"neither caption file nor class tokens are found. 
use empty caption for {img_path}") captions.append("") else: - captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + if missing_captions: + number_of_missing_captions = len(missing_captions) + number_of_missing_captions_to_show = 5 + remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show + + print(f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images") + for i, missing_caption in enumerate(missing_captions): + if i >= number_of_missing_captions_to_show: + print(missing_caption+f"... and {remaining_missing_captions} more") + break + print(missing_caption) + time.sleep(5) return img_paths, captions print("prepare images.") From d859a3a9259dd04d03a41816e17c8fd8bb0189ee Mon Sep 17 00:00:00 2001 From: TingTingin <36141041+TingTingin@users.noreply.github.com> Date: Tue, 23 May 2023 02:00:33 -0400 Subject: [PATCH 02/11] Update train_util.py fix mistake --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 05ec7f84..576fc5d8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1071,7 +1071,7 @@ class DreamBoothDataset(BaseDataset): assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() break - return caption + return caption def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): From 061e1571910d704e92e610bc48f90d9f21996afe Mon Sep 17 00:00:00 2001 From: TingTingin <36141041+TingTingin@users.noreply.github.com> Date: Tue, 23 May 2023 02:02:39 -0400 Subject: [PATCH 03/11] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 576fc5d8..55eeb316 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1071,7 +1071,7 @@ class DreamBoothDataset(BaseDataset): assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() break - return caption + return caption def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): From db756e9a343bcbc0efb2de077520233b70b5810a Mon Sep 17 00:00:00 2001 From: TingTingin <36141041+TingTingin@users.noreply.github.com> Date: Fri, 26 May 2023 08:08:34 -0400 Subject: [PATCH 04/11] Update train_util.py I removed the sleep since it triggers per subset and if someone had a lot of subsets it would trigger multiple times --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 55eeb316..09e6a366 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1109,7 +1109,6 @@ class DreamBoothDataset(BaseDataset): print(missing_caption+f"... 
and {remaining_missing_captions} more") break print(missing_caption) - time.sleep(5) return img_paths, captions print("prepare images.") From dd8e17cb37bcc9f5e57187a8c359082db40a5c8a Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sat, 27 May 2023 05:15:02 +0900 Subject: [PATCH 05/11] =?UTF-8?q?=E5=B7=AE=E5=88=86=E5=AD=A6=E7=BF=92?= =?UTF-8?q?=E6=A9=9F=E8=83=BD=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_network.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index f2fd2009..14084b67 100644 --- a/train_network.py +++ b/train_network.py @@ -148,6 +148,20 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + # prepare network + import sys + + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + if args.base_modules is not None: + for module_path in args.base_modules: + print("merging module: %s"%module_path) + module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu") + print("all modules merged: %s"%", ".join(args.base_modules)) + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -162,13 +176,6 @@ def train(args): accelerator.wait_for_everyone() - # prepare network - import sys - - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: @@ -770,6 +777,16 @@ def setup_parser() -> 
argparse.ArgumentParser: action="store_true", help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) + parser.add_argument( + "--base_modules", + type=str, default=None, nargs="*", + help="base modules for differential learning / 差分学習用のベースモデル", + ) + parser.add_argument( + "--base_modules_weight", + type=float, default=1, + help="weight of base modules for differential learning / 差分学習用のベースモデルの比重", + ) return parser From 990ceddd1499fb7e43a7ff6c986d7453a80afabb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 30 May 2023 22:53:50 +0900 Subject: [PATCH 06/11] show warning if no caption and no class token --- library/train_util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index f241193c..d13c2d87 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1087,8 +1087,9 @@ class DreamBoothDataset(BaseDataset): for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: - print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}") captions.append("") + missing_captions.append(img_path) else: if cap_for_img is None: captions.append(subset.class_tokens) @@ -1103,10 +1104,12 @@ class DreamBoothDataset(BaseDataset): number_of_missing_captions_to_show = 5 remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - print(f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images") + print( + f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. 
/ {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" + ) for i, missing_caption in enumerate(missing_captions): if i >= number_of_missing_captions_to_show: - print(missing_caption+f"... and {remaining_missing_captions} more") + print(missing_caption + f"... and {remaining_missing_captions} more") break print(missing_caption) return img_paths, captions From fc006918985ed15f3e9b77bf8812b87b60703530 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 30 May 2023 23:10:41 +0900 Subject: [PATCH 07/11] enable multiple module weights --- train_network.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/train_network.py b/train_network.py index 14084b67..8525efd7 100644 --- a/train_network.py +++ b/train_network.py @@ -148,7 +148,7 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - # prepare network + # 差分追加学習のためにモデルを読み込む import sys sys.path.append(os.path.dirname(__file__)) @@ -156,11 +156,21 @@ def train(args): network_module = importlib.import_module(args.network_module) if args.base_modules is not None: - for module_path in args.base_modules: - print("merging module: %s"%module_path) - module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True) - module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu") - print("all modules merged: %s"%", ".join(args.base_modules)) + # base_modules が指定されている場合は、指定されたモジュールを読み込みマージする + for i, module_path in enumerate(args.base_modules): + print(f"merging module: {module_path}") + + if args.base_modules_weights is None or len(args.base_modules_weights) <= i: + weight = 1.0 + else: + weight = args.base_modules_weights[i] + + module, weights_sd = network_module.create_network_from_weights( + weight, module_path, vae, text_encoder, unet, 
for_inference=True + ) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") + + print(f"all modules merged: {', '.join(args.base_modules)}") # 学習を準備する if cache_latents: @@ -176,6 +186,7 @@ def train(args): accelerator.wait_for_everyone() + # prepare network net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: @@ -779,13 +790,17 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--base_modules", - type=str, default=None, nargs="*", + type=str, + default=None, + nargs="*", help="base modules for differential learning / 差分学習用のベースモデル", ) parser.add_argument( "--base_modules_weight", - type=float, default=1, - help="weight of base modules for differential learning / 差分学習用のベースモデルの比重", + type=float, + default=None, + nargs="*", + help="weights of base modules for differential learning / 差分学習用のベースモデルの比重", ) return parser From c437dce056e459dbee075f7343985b7c079b8fd5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 30 May 2023 23:19:29 +0900 Subject: [PATCH 08/11] change option name for merging network weights --- train_network.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/train_network.py b/train_network.py index 8525efd7..8c383b0b 100644 --- a/train_network.py +++ b/train_network.py @@ -155,22 +155,22 @@ def train(args): print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) - if args.base_modules is not None: - # base_modules が指定されている場合は、指定されたモジュールを読み込みマージする - for i, module_path in enumerate(args.base_modules): - print(f"merging module: {module_path}") + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + print(f"merging module: {weight_path}") - if args.base_modules_weights is None or len(args.base_modules_weights) <= i: - weight = 1.0 + if 
args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 else: - weight = args.base_modules_weights[i] + multiplier = args.base_weights_multiplier[i] module, weights_sd = network_module.create_network_from_weights( - weight, module_path, vae, text_encoder, unet, for_inference=True + multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - print(f"all modules merged: {', '.join(args.base_modules)}") + print(f"all weights merged: {', '.join(args.base_weights)}") # 学習を準備する if cache_latents: @@ -789,18 +789,18 @@ def setup_parser() -> argparse.ArgumentParser: help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) parser.add_argument( - "--base_modules", + "--base_weights", type=str, default=None, nargs="*", - help="base modules for differential learning / 差分学習用のベースモデル", + help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", ) parser.add_argument( - "--base_modules_weight", + "--base_weights_multiplier", type=float, default=None, nargs="*", - help="weights of base modules for differential learning / 差分学習用のベースモデルの比重", + help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) return parser From 6fbd52693128faff61707c24551ce5a3dbb8f682 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 31 May 2023 20:23:19 +0900 Subject: [PATCH 09/11] show multiplier for base weights to console --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 8c383b0b..c94d6afc 100644 --- a/train_network.py +++ b/train_network.py @@ -158,13 +158,13 @@ def train(args): if args.base_weights is not None: # base_weights が指定されている場合は、指定された重みを読み込みマージする for i, weight_path in enumerate(args.base_weights): - 
print(f"merging module: {weight_path}") - if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: multiplier = 1.0 else: multiplier = args.base_weights_multiplier[i] + print(f"merging module: {weight_path} with multiplier {multiplier}") + module, weights_sd = network_module.create_network_from_weights( multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) From 3a0696833222507147bdb03915d39ddc6bde8751 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 31 May 2023 20:48:33 +0900 Subject: [PATCH 10/11] warn and continue if huggingface uploading failed --- library/huggingface_util.py | 52 ++++++++++++++++++++----------------- library/train_util.py | 4 ++- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 2d0e1980..1dc496ff 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -6,9 +6,7 @@ import os from library.utils import fire_in_thread -def exists_repo( - repo_id: str, repo_type: str, revision: str = "main", token: str = None -): +def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( token=token, ) @@ -32,27 +30,35 @@ def upload( private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" api = HfApi(token=token) if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): - api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので + print("===========================================") + print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + print("===========================================") - is_folder = (type(src) == str and os.path.isdir(src)) or ( - isinstance(src, Path) and src.is_dir() - ) + is_folder = (type(src) == str and 
os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) def uploader(): - if is_folder: - api.upload_folder( - repo_id=repo_id, - repo_type=repo_type, - folder_path=src, - path_in_repo=path_in_repo, - ) - else: - api.upload_file( - repo_id=repo_id, - repo_type=repo_type, - path_or_fileobj=src, - path_in_repo=path_in_repo, - ) + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので + print("===========================================") + print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + print("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) @@ -71,7 +77,5 @@ def list_dir( token=token, ) repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) - file_list = [ - file for file in repo_info.siblings if file.rfilename.startswith(subfolder) - ] + file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] return file_list diff --git a/library/train_util.py b/library/train_util.py index d13c2d87..d963537d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1087,7 +1087,9 @@ class DreamBoothDataset(BaseDataset): for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: - print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}") + print( + f"neither caption file nor class tokens are found. 
use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) captions.append("") missing_captions.append(img_path) else: From a002d10a4d9887ec4ecbc0d983063608646e31d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 31 May 2023 20:57:01 +0900 Subject: [PATCH 11/11] update readme --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/README.md b/README.md index dbd1d17b..aefc6c35 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,24 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 31 May 2023, 2023/05/31 + +- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin! + - Warning is also displayed when using class+identifier dataset. Please ignore if it is intended. +- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru! + - `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge. + - `--base_weights_multiplier` option specifies multiplier of the weights to merge (multiple values are allowed). If omitted, or if fewer multipliers than `base_weights` files are given, 1.0 is used for the remaining files. + - This is useful for incremental learning. See PR for details. +- Show warning and continue training when uploading to HuggingFace fails. 
+ +- 学習時に画像のキャプションファイルが存在しない場合、警告が表示されるようになりました。 [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) TingTingin氏に感謝します。 + - class+identifier方式のデータセットを利用している場合も警告が表示されます。意図している通りの場合は無視してください。 +- `train_network.py` に学習前にモデルにnetworkの重みをマージする機能が追加されました。 [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) u-haru氏に感謝します。 + - `--base_weights` オプションでLoRA等のモデルファイル(複数可)を指定すると、それらの重みをマージします。 + - `--base_weights_multiplier` オプションでマージする重みの倍率(複数可)を指定できます。省略時または`base_weights`よりも数が少ない場合は1.0になります。 + - 差分追加学習などにご利用ください。詳細はPRをご覧ください。 +- HuggingFaceへのアップロードに失敗した場合、警告を表示しそのまま学習を続行するよう変更しました。 + ### 25 May 2023, 2023/05/25 - [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds!