From f6117263649078680661e319b9a469268a87b43b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 21:41:50 +0900 Subject: [PATCH 1/4] add network_merge_n_models option --- gen_img_diffusers.py | 55 ++++++++++++++++++++++++++++---------------- sdxl_gen_img.py | 54 ++++++++++++++++++++++++++++--------------- 2 files changed, 70 insertions(+), 39 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 0ec683a2..82002834 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,10 +65,13 @@ import re import diffusers import numpy as np import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -954,7 +957,7 @@ class PipelineLike: text_emb_last = torch.stack(text_emb_last) else: text_emb_last = text_embeddings - + for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) @@ -2363,12 +2366,19 @@ def main(args): network_default_muls = [] network_pre_calc = args.network_pre_calc + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = None + for i, network_module in enumerate(args.network_module): print("import network module:", network_module) imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -2379,31 +2389,32 @@ def main(args): key, value = net_arg.split("=") net_kwargs[key] = value - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs - ) - else: + if args.network_weights is None or len(args.network_weights) <= i: raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs + ) if network is None: return mergeable = network.is_mergeable() - if args.network_merge and not mergeable: + if network_merge and not mergeable: print("network is not mergiable. ignore merge option.") - if not args.network_merge or not mergeable: + if not mergeable or i >= network_merge: + # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") @@ -2417,6 +2428,7 @@ def main(args): network.backup_weights() networks.append(network) + network_default_muls.append(network_mul) else: network.merge_to(text_encoder, unet, weights_sd, dtype, device) @@ -3367,6 +3379,9 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + ) parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab2b6b3d..2d652bc8 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -17,10 +17,13 @@ import re import diffusers import numpy as np import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -1534,12 +1537,20 @@ def main(args): network_default_muls = [] network_pre_calc = args.network_pre_calc + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = None + print(f"network_merge: {network_merge}") + for i, network_module in enumerate(args.network_module): print("import network module:", network_module) imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -1550,31 +1561,32 @@ def main(args): key, value = net_arg.split("=") net_kwargs[key] = value - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs - ) - else: + if args.network_weights is None or len(args.network_weights) <= i: raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) if network is None: return mergeable = network.is_mergeable() - if args.network_merge and not mergeable: + if network_merge and not mergeable: print("network is not mergiable. ignore merge option.") - if not args.network_merge or not mergeable: + if not mergeable or i >= network_merge: + # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") @@ -1588,6 +1600,7 @@ def main(args): network.backup_weights() networks.append(network) + network_default_muls.append(network_mul) else: network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) @@ -2615,6 +2628,9 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + ) parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" From 3e81bd6b6729bf8b553b30cbce23e21698287039 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 23:07:14 +0900 Subject: [PATCH 2/4] fix network_merge, add regional mask as color code --- gen_img_diffusers.py | 21 ++++++++++++++++++--- sdxl_gen_img.py | 21 ++++++++++++++++++--- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 82002834..a596a049 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2372,7 +2372,7 @@ def main(args): elif args.network_merge_n_models: network_merge = args.network_merge_n_models else: - network_merge = None + network_merge = 0 for i, network_module in enumerate(args.network_module): print("import network module:", network_module) @@ -2724,9 +2724,18 @@ def main(args): size = None for i, network in enumerate(networks): - if i < 3: + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] size = np_mask.shape else: np_mask = np.full(size, 255, dtype=np.uint8) @@ -3386,6 +3395,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) parser.add_argument( "--textual_inversion_embeddings", type=str, diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 2d652bc8..c31ae007 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1543,7 +1543,7 @@ def main(args): elif args.network_merge_n_models: network_merge = args.network_merge_n_models else: - network_merge = None + network_merge = 0 print(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): @@ -1877,9 +1877,18 @@ def main(args): size = None for i, network in enumerate(networks): - if i < 3: + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] size = np_mask.shape else: np_mask = np.full(size, 255, dtype=np.uint8) @@ -2635,6 +2644,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) parser.add_argument( "--textual_inversion_embeddings", type=str, From 17813ff5b4f38a79d64341c609d9941c7cb4172e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Oct 2023 07:40:12 +0900 Subject: [PATCH 3/4] remove workaround for transfomers bs>1 close #869 --- finetune/make_captions_by_git.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index ce6e6695..b3c5cc42 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch): def main(args): + r""" + transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト + # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように @@ -65,6 +68,7 @@ def main(args): return input_ids GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch + """ print(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) @@ -81,7 +85,7 @@ def main(args): def run_batch(path_imgs): imgs = [im for _, im in path_imgs] - curr_batch_size[0] = len(path_imgs) + # curr_batch_size[0] = len(path_imgs) inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) From 681034d0014fc9c60bdc5687cadb2106b6f17b3f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Oct 2023 07:54:30 +0900 Subject: [PATCH 4/4] update readme --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 5da6181b..29c46187 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,21 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Oct 11, 2023 / 2023/10/11 +- Fix to work `make_captions_by_git.py` with the latest version of transformers. +- Improve `gen_img_diffusers.py` and `sdxl_gen_img.py`. Both scripts now support the following options: + - `--network_merge_n_models` option can be used to merge some of the models. The remaining models aren't merged, so the multiplier can be changed, and the regional LoRA also works. + - `--network_regional_mask_max_color_codes` is added. Now you can use up to 7 regions. + - When this option is specified, the mask of the regional LoRA is the color code based instead of the channel based. The value is the maximum number of the color codes (up to 7). + - You can specify the mask for each LoRA by colors: 0x0000ff, 0x00ff00, 0x00ffff, 0xff0000, 0xff00ff, 0xffff00, 0xffffff. + +- `make_captions_by_git.py` が最新の transformers で動作するように修正しました。 +- `gen_img_diffusers.py` と `sdxl_gen_img.py` を更新し、以下のオプションを追加しました。 + - `--network_merge_n_models` オプションで一部のモデルのみマージできます。残りのモデルはマージされないため、重みを変更したり、領域別LoRAを使用したりできます。 + - `--network_regional_mask_max_color_codes` を追加しました。最大7つの領域を使用できます。 + - このオプションを指定すると、領域別LoRAのマスクはチャンネルベースではなくカラーコードベースになります。値はカラーコードの最大数(最大7)です。 + - 各LoRAに対してマスクをカラーで指定できます:0x0000ff、0x00ff00、0x00ffff、0xff0000、0xff00ff、0xffff00、0xffffff。 + ### Oct 9. 2023 / 2023/10/9 - `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!