Compare commits

...

6 Commits

Author SHA1 Message Date
Kohya S
2a23713f71 Merge pull request #872 from kohya-ss/dev
fix make_captions_by_git, improve image generation scripts
2023-10-11 07:56:39 +09:00
Kohya S
681034d001 update readme 2023-10-11 07:54:30 +09:00
Kohya S
17813ff5b4 remove workaround for transfomers bs>1 close #869 2023-10-11 07:40:12 +09:00
Kohya S
3e81bd6b67 fix network_merge, add regional mask as color code 2023-10-09 23:07:14 +09:00
Kohya S
23ae358e0f Merge branch 'main' into dev 2023-10-09 21:42:13 +09:00
Kohya S
f611726364 add network_merge_n_models option 2023-10-09 21:41:50 +09:00
4 changed files with 124 additions and 44 deletions

View File

@@ -249,6 +249,21 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History ## Change History
### Oct 11, 2023 / 2023/10/11
- Fix to work `make_captions_by_git.py` with the latest version of transformers.
- Improve `gen_img_diffusers.py` and `sdxl_gen_img.py`. Both scripts now support the following options:
- `--network_merge_n_models` option can be used to merge some of the models. The remaining models aren't merged, so the multiplier can be changed, and the regional LoRA also works.
- `--network_regional_mask_max_color_codes` is added. Now you can use up to 7 regions.
- When this option is specified, the mask of the regional LoRA is the color code based instead of the channel based. The value is the maximum number of the color codes (up to 7).
- You can specify the mask for each LoRA by colors: 0x0000ff, 0x00ff00, 0x00ffff, 0xff0000, 0xff00ff, 0xffff00, 0xffffff.
- `make_captions_by_git.py` が最新の transformers で動作するように修正しました。
- `gen_img_diffusers.py` と `sdxl_gen_img.py` を更新し、以下のオプションを追加しました。
- `--network_merge_n_models` オプションで一部のモデルのみマージできます。残りのモデルはマージされないため、重みを変更したり、領域別LoRAを使用したりできます。
- `--network_regional_mask_max_color_codes` を追加しました。最大7つの領域を使用できます。
    - このオプションを指定すると、領域別LoRAのマスクはチャンネルベースではなくカラーコードベースになります。値はカラーコードの最大数(最大7)です。
    - 各LoRAに対してマスクをカラーで指定できます:0x0000ff、0x00ff00、0x00ffff、0xff0000、0xff00ff、0xffff00、0xffffff。
### Oct 9, 2023 / 2023/10/9 ### Oct 9, 2023 / 2023/10/9
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py! - `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!

View File

@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):
def main(args): def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
@@ -65,6 +68,7 @@ def main(args):
return input_ids return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}") print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
@@ -81,7 +85,7 @@ def main(args):
def run_batch(path_imgs): def run_batch(path_imgs):
imgs = [im for _, im in path_imgs] imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs) # curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)

View File

@@ -65,10 +65,13 @@ import re
import diffusers import diffusers
import numpy as np import numpy as np
import torch import torch
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available(): if torch.xpu.is_available():
from library.ipex import ipex_init from library.ipex import ipex_init
ipex_init() ipex_init()
except Exception: except Exception:
pass pass
@@ -2363,12 +2366,19 @@ def main(args):
network_default_muls = [] network_default_muls = []
network_pre_calc = args.network_pre_calc network_pre_calc = args.network_pre_calc
# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) print("import network module:", network_module)
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@@ -2379,31 +2389,32 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
net_kwargs[key] = value net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights): if args.network_weights is None or len(args.network_weights) <= i:
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.") raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
if network is None: if network is None:
return return
mergeable = network.is_mergeable() mergeable = network.is_mergeable()
if args.network_merge and not mergeable: if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.") print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable: if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
@@ -2417,6 +2428,7 @@ def main(args):
network.backup_weights() network.backup_weights()
networks.append(network) networks.append(network)
network_default_muls.append(network_mul)
else: else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device) network.merge_to(text_encoder, unet, weights_sd, dtype, device)
@@ -2712,9 +2724,18 @@ def main(args):
size = None size = None
for i, network in enumerate(networks): for i, network in enumerate(networks):
if i < 3: if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0]) np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape size = np_mask.shape
else: else:
np_mask = np.full(size, 255, dtype=np.uint8) np_mask = np.full(size, 255, dtype=np.uint8)
@@ -3367,10 +3388,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
) )
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument( parser.add_argument(
"--textual_inversion_embeddings", "--textual_inversion_embeddings",
type=str, type=str,

View File

@@ -17,10 +17,13 @@ import re
import diffusers import diffusers
import numpy as np import numpy as np
import torch import torch
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available(): if torch.xpu.is_available():
from library.ipex import ipex_init from library.ipex import ipex_init
ipex_init() ipex_init()
except Exception: except Exception:
pass pass
@@ -1534,12 +1537,20 @@ def main(args):
network_default_muls = [] network_default_muls = []
network_pre_calc = args.network_pre_calc network_pre_calc = args.network_pre_calc
# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
print(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) print("import network module:", network_module)
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@@ -1550,31 +1561,32 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
net_kwargs[key] = value net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights): if args.network_weights is None or len(args.network_weights) <= i:
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.") raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
if network is None: if network is None:
return return
mergeable = network.is_mergeable() mergeable = network.is_mergeable()
if args.network_merge and not mergeable: if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.") print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable: if not mergeable or i >= network_merge:
# not merging
network.apply_to([text_encoder1, text_encoder2], unet) network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
@@ -1588,6 +1600,7 @@ def main(args):
network.backup_weights() network.backup_weights()
networks.append(network) networks.append(network)
network_default_muls.append(network_mul)
else: else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
@@ -1864,9 +1877,18 @@ def main(args):
size = None size = None
for i, network in enumerate(networks): for i, network in enumerate(networks):
if i < 3: if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0]) np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape size = np_mask.shape
else: else:
np_mask = np.full(size, 255, dtype=np.uint8) np_mask = np.full(size, 255, dtype=np.uint8)
@@ -2615,10 +2637,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
) )
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument( parser.add_argument(
"--textual_inversion_embeddings", "--textual_inversion_embeddings",
type=str, type=str,