mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add network_merge_n_models option
This commit is contained in:
@@ -65,10 +65,13 @@ import re
|
|||||||
import diffusers
|
import diffusers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
if torch.xpu.is_available():
|
||||||
from library.ipex import ipex_init
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
ipex_init()
|
ipex_init()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -2363,12 +2366,19 @@ def main(args):
|
|||||||
network_default_muls = []
|
network_default_muls = []
|
||||||
network_pre_calc = args.network_pre_calc
|
network_pre_calc = args.network_pre_calc
|
||||||
|
|
||||||
|
# merge関連の引数を統合する
|
||||||
|
if args.network_merge:
|
||||||
|
network_merge = len(args.network_module) # all networks are merged
|
||||||
|
elif args.network_merge_n_models:
|
||||||
|
network_merge = args.network_merge_n_models
|
||||||
|
else:
|
||||||
|
network_merge = None
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
print("import network module:", network_module)
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
network_default_muls.append(network_mul)
|
|
||||||
|
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.network_args and i < len(args.network_args):
|
if args.network_args and i < len(args.network_args):
|
||||||
@@ -2379,31 +2389,32 @@ def main(args):
|
|||||||
key, value = net_arg.split("=")
|
key, value = net_arg.split("=")
|
||||||
net_kwargs[key] = value
|
net_kwargs[key] = value
|
||||||
|
|
||||||
if args.network_weights and i < len(args.network_weights):
|
if args.network_weights is None or len(args.network_weights) <= i:
|
||||||
network_weight = args.network_weights[i]
|
|
||||||
print("load network weights from:", network_weight)
|
|
||||||
|
|
||||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
|
||||||
from safetensors.torch import safe_open
|
|
||||||
|
|
||||||
with safe_open(network_weight, framework="pt") as f:
|
|
||||||
metadata = f.metadata()
|
|
||||||
if metadata is not None:
|
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
|
||||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
|
|
||||||
|
network_weight = args.network_weights[i]
|
||||||
|
print("load network weights from:", network_weight)
|
||||||
|
|
||||||
|
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||||
|
from safetensors.torch import safe_open
|
||||||
|
|
||||||
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata is not None:
|
||||||
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
|
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||||
|
)
|
||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
mergeable = network.is_mergeable()
|
mergeable = network.is_mergeable()
|
||||||
if args.network_merge and not mergeable:
|
if network_merge and not mergeable:
|
||||||
print("network is not mergiable. ignore merge option.")
|
print("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
if not args.network_merge or not mergeable:
|
if not mergeable or i >= network_merge:
|
||||||
|
# not merging
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
@@ -2417,6 +2428,7 @@ def main(args):
|
|||||||
network.backup_weights()
|
network.backup_weights()
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
|
network_default_muls.append(network_mul)
|
||||||
else:
|
else:
|
||||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||||
|
|
||||||
@@ -3367,6 +3379,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||||
)
|
)
|
||||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||||
|
parser.add_argument(
|
||||||
|
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||||
|
)
|
||||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||||
|
|||||||
@@ -17,10 +17,13 @@ import re
|
|||||||
import diffusers
|
import diffusers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
if torch.xpu.is_available():
|
||||||
from library.ipex import ipex_init
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
ipex_init()
|
ipex_init()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -1534,12 +1537,20 @@ def main(args):
|
|||||||
network_default_muls = []
|
network_default_muls = []
|
||||||
network_pre_calc = args.network_pre_calc
|
network_pre_calc = args.network_pre_calc
|
||||||
|
|
||||||
|
# merge関連の引数を統合する
|
||||||
|
if args.network_merge:
|
||||||
|
network_merge = len(args.network_module) # all networks are merged
|
||||||
|
elif args.network_merge_n_models:
|
||||||
|
network_merge = args.network_merge_n_models
|
||||||
|
else:
|
||||||
|
network_merge = None
|
||||||
|
print(f"network_merge: {network_merge}")
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
print("import network module:", network_module)
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
network_default_muls.append(network_mul)
|
|
||||||
|
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.network_args and i < len(args.network_args):
|
if args.network_args and i < len(args.network_args):
|
||||||
@@ -1550,31 +1561,32 @@ def main(args):
|
|||||||
key, value = net_arg.split("=")
|
key, value = net_arg.split("=")
|
||||||
net_kwargs[key] = value
|
net_kwargs[key] = value
|
||||||
|
|
||||||
if args.network_weights and i < len(args.network_weights):
|
if args.network_weights is None or len(args.network_weights) <= i:
|
||||||
network_weight = args.network_weights[i]
|
|
||||||
print("load network weights from:", network_weight)
|
|
||||||
|
|
||||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
|
||||||
from safetensors.torch import safe_open
|
|
||||||
|
|
||||||
with safe_open(network_weight, framework="pt") as f:
|
|
||||||
metadata = f.metadata()
|
|
||||||
if metadata is not None:
|
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
|
||||||
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
|
|
||||||
|
network_weight = args.network_weights[i]
|
||||||
|
print("load network weights from:", network_weight)
|
||||||
|
|
||||||
|
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||||
|
from safetensors.torch import safe_open
|
||||||
|
|
||||||
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata is not None:
|
||||||
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
|
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
||||||
|
)
|
||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
mergeable = network.is_mergeable()
|
mergeable = network.is_mergeable()
|
||||||
if args.network_merge and not mergeable:
|
if network_merge and not mergeable:
|
||||||
print("network is not mergiable. ignore merge option.")
|
print("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
if not args.network_merge or not mergeable:
|
if not mergeable or i >= network_merge:
|
||||||
|
# not merging
|
||||||
network.apply_to([text_encoder1, text_encoder2], unet)
|
network.apply_to([text_encoder1, text_encoder2], unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
@@ -1588,6 +1600,7 @@ def main(args):
|
|||||||
network.backup_weights()
|
network.backup_weights()
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
|
network_default_muls.append(network_mul)
|
||||||
else:
|
else:
|
||||||
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
|
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
|
||||||
|
|
||||||
@@ -2615,6 +2628,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||||
)
|
)
|
||||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||||
|
parser.add_argument(
|
||||||
|
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||||
|
)
|
||||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||||
|
|||||||
Reference in New Issue
Block a user