enable multiple module weights

This commit is contained in:
Kohya S
2023-05-30 23:10:41 +09:00
parent 990ceddd14
commit fc00691898

View File

@@ -148,7 +148,7 @@ def train(args):
# incorporate xformers or memory efficient attention into the model / モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# prepare network
# load the base modules for differential additional learning / 差分追加学習のためにモデルを読み込む
import sys
sys.path.append(os.path.dirname(__file__))
@@ -156,11 +156,21 @@ def train(args):
network_module = importlib.import_module(args.network_module)
if args.base_modules is not None:
for module_path in args.base_modules:
print("merging module: %s"%module_path)
module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu")
print("all modules merged: %s"%", ".join(args.base_modules))
# if base_modules is specified, load and merge the specified modules / base_modules が指定されている場合は、指定されたモジュールを読み込みマージする
for i, module_path in enumerate(args.base_modules):
print(f"merging module: {module_path}")
if args.base_modules_weights is None or len(args.base_modules_weights) <= i:
weight = 1.0
else:
weight = args.base_modules_weights[i]
module, weights_sd = network_module.create_network_from_weights(
weight, module_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
print(f"all modules merged: {', '.join(args.base_modules)}")
# prepare for training / 学習を準備する
if cache_latents:
@@ -176,6 +186,7 @@ def train(args):
accelerator.wait_for_everyone()
# prepare network
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
@@ -779,13 +790,17 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--base_modules",
type=str, default=None, nargs="*",
type=str,
default=None,
nargs="*",
help="base modules for differential learning / 差分学習用のベースモデル",
)
parser.add_argument(
"--base_modules_weight",
type=float, default=1,
help="weight of base modules for differential learning / 差分学習用のベースモデルの比重",
type=float,
default=None,
nargs="*",
help="weights of base modules for differential learning / 差分学習用のベースモデルの比重",
)
return parser