mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
enable multiple module weights
This commit is contained in:
@@ -148,7 +148,7 @@ def train(args):
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# prepare network
|
||||
# 差分追加学習のためにモデルを読み込む
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
@@ -156,11 +156,21 @@ def train(args):
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_modules is not None:
|
||||
for module_path in args.base_modules:
|
||||
print("merging module: %s"%module_path)
|
||||
module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu")
|
||||
print("all modules merged: %s"%", ".join(args.base_modules))
|
||||
# base_modules が指定されている場合は、指定されたモジュールを読み込みマージする
|
||||
for i, module_path in enumerate(args.base_modules):
|
||||
print(f"merging module: {module_path}")
|
||||
|
||||
if args.base_modules_weights is None or len(args.base_modules_weights) <= i:
|
||||
weight = 1.0
|
||||
else:
|
||||
weight = args.base_modules_weights[i]
|
||||
|
||||
module, weights_sd = network_module.create_network_from_weights(
|
||||
weight, module_path, vae, text_encoder, unet, for_inference=True
|
||||
)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
print(f"all modules merged: {', '.join(args.base_modules)}")
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -176,6 +186,7 @@ def train(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
for net_arg in args.network_args:
|
||||
@@ -779,13 +790,17 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_modules",
|
||||
type=str, default=None, nargs="*",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="base modules for differential learning / 差分学習用のベースモデル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_modules_weight",
|
||||
type=float, default=1,
|
||||
help="weight of base modules for differential learning / 差分学習用のベースモデルの比重",
|
||||
type=float,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="weights of base modules for differential learning / 差分学習用のベースモデルの比重",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
Reference in New Issue
Block a user