Add --network_merge_n_models option (merge only the first N networks into the model)

This commit is contained in:
Kohya S
2023-10-09 21:41:50 +09:00
parent 8b79e3b06c
commit f611726364
2 changed files with 70 additions and 39 deletions

View File

@@ -65,10 +65,13 @@ import re
import diffusers import diffusers
import numpy as np import numpy as np
import torch import torch
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available(): if torch.xpu.is_available():
from library.ipex import ipex_init from library.ipex import ipex_init
ipex_init() ipex_init()
except Exception: except Exception:
pass pass
@@ -2363,12 +2366,19 @@ def main(args):
network_default_muls = [] network_default_muls = []
network_pre_calc = args.network_pre_calc network_pre_calc = args.network_pre_calc
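# Consolidate the merge-related arguments into one value: the number of leading networks to merge.
# NOTE(review): this stays None when neither --network_merge nor --network_merge_n_models is given; the later `i >= network_merge` check would then raise TypeError for a mergeable network. Defaulting to 0 looks safer; confirm.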
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = None
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) print("import network module:", network_module)
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@@ -2379,31 +2389,32 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
net_kwargs[key] = value net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights): if args.network_weights is None or len(args.network_weights) <= i:
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.") raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
if network is None: if network is None:
return return
mergeable = network.is_mergeable() mergeable = network.is_mergeable()
if args.network_merge and not mergeable: if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.") print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable: if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
@@ -2417,6 +2428,7 @@ def main(args):
network.backup_weights() network.backup_weights()
networks.append(network) networks.append(network)
network_default_muls.append(network_mul)
else: else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device) network.merge_to(text_encoder, unet, weights_sd, dtype, device)
@@ -3367,6 +3379,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"

View File

@@ -17,10 +17,13 @@ import re
import diffusers import diffusers
import numpy as np import numpy as np
import torch import torch
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available(): if torch.xpu.is_available():
from library.ipex import ipex_init from library.ipex import ipex_init
ipex_init() ipex_init()
except Exception: except Exception:
pass pass
@@ -1534,12 +1537,20 @@ def main(args):
network_default_muls = [] network_default_muls = []
network_pre_calc = args.network_pre_calc network_pre_calc = args.network_pre_calc
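# Consolidate the merge-related arguments into one value: the number of leading networks to merge.
# NOTE(review): this stays None when neither --network_merge nor --network_merge_n_models is given; the later `i >= network_merge` check would then raise TypeError for a mergeable network. Defaulting to 0 looks safer; confirm.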
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = None
print(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) print("import network module:", network_module)
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@@ -1550,31 +1561,32 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
net_kwargs[key] = value net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights): if args.network_weights is None or len(args.network_weights) <= i:
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.") raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
if network is None: if network is None:
return return
mergeable = network.is_mergeable() mergeable = network.is_mergeable()
if args.network_merge and not mergeable: if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.") print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable: if not mergeable or i >= network_merge:
# not merging
network.apply_to([text_encoder1, text_encoder2], unet) network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
@@ -1588,6 +1600,7 @@ def main(args):
network.backup_weights() network.backup_weights()
networks.append(network) networks.append(network)
network_default_muls.append(network_mul)
else: else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
@@ -2615,6 +2628,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"