mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add network_merge_n_models option
This commit is contained in:
@@ -65,10 +65,13 @@ import re
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -954,7 +957,7 @@ class PipelineLike:
|
||||
text_emb_last = torch.stack(text_emb_last)
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
@@ -2363,12 +2366,19 @@ def main(args):
|
||||
network_default_muls = []
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
# merge関連の引数を統合する
|
||||
if args.network_merge:
|
||||
network_merge = len(args.network_module) # all networks are merged
|
||||
elif args.network_merge_n_models:
|
||||
network_merge = args.network_merge_n_models
|
||||
else:
|
||||
network_merge = None
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_default_muls.append(network_mul)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -2379,31 +2389,32 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
if args.network_merge and not mergeable:
|
||||
if network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergeable:
|
||||
if not mergeable or i >= network_merge:
|
||||
# not merging
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -2417,6 +2428,7 @@ def main(args):
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
network_default_muls.append(network_mul)
|
||||
else:
|
||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||
|
||||
@@ -3367,6 +3379,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||
)
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
|
||||
Reference in New Issue
Block a user