change option name for merging network weights

This commit is contained in:
Kohya S
2023-05-30 23:19:29 +09:00
parent fc00691898
commit c437dce056

View File

@@ -155,22 +155,22 @@ def train(args):
print("import network module:", args.network_module) print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module) network_module = importlib.import_module(args.network_module)
if args.base_modules is not None: if args.base_weights is not None:
# base_modules が指定されている場合は、指定されたモジュールを読み込みマージする # base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, module_path in enumerate(args.base_modules): for i, weight_path in enumerate(args.base_weights):
print(f"merging module: {module_path}") print(f"merging module: {weight_path}")
if args.base_modules_weights is None or len(args.base_modules_weights) <= i: if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
weight = 1.0 multiplier = 1.0
else: else:
weight = args.base_modules_weights[i] multiplier = args.base_weights_multiplier[i]
module, weights_sd = network_module.create_network_from_weights( module, weights_sd = network_module.create_network_from_weights(
weight, module_path, vae, text_encoder, unet, for_inference=True multiplier, weight_path, vae, text_encoder, unet, for_inference=True
) )
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
print(f"all modules merged: {', '.join(args.base_modules)}") print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
@@ -789,18 +789,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
) )
parser.add_argument( parser.add_argument(
"--base_modules", "--base_weights",
type=str, type=str,
default=None, default=None,
nargs="*", nargs="*",
help="base modules for differential learning / 差分学習用のベースモデ", help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイ",
) )
parser.add_argument( parser.add_argument(
"--base_modules_weight", "--base_weights_multiplier",
type=float, type=float,
default=None, default=None,
nargs="*", nargs="*",
help="weights of base modules for differential learning / 差分学習用のベースモデルの比重", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
) )
return parser return parser