diff --git a/train_network.py b/train_network.py index 8525efd7..8c383b0b 100644 --- a/train_network.py +++ b/train_network.py @@ -155,22 +155,22 @@ def train(args): print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) - if args.base_modules is not None: - # base_modules が指定されている場合は、指定されたモジュールを読み込みマージする - for i, module_path in enumerate(args.base_modules): - print(f"merging module: {module_path}") + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + print(f"merging module: {weight_path}") - if args.base_modules_weights is None or len(args.base_modules_weights) <= i: - weight = 1.0 + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 else: - weight = args.base_modules_weights[i] + multiplier = args.base_weights_multiplier[i] module, weights_sd = network_module.create_network_from_weights( - weight, module_path, vae, text_encoder, unet, for_inference=True + multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - print(f"all modules merged: {', '.join(args.base_modules)}") + print(f"all weights merged: {', '.join(args.base_weights)}") # 学習を準備する if cache_latents: @@ -789,18 +789,18 @@ def setup_parser() -> argparse.ArgumentParser: help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) parser.add_argument( - "--base_modules", + "--base_weights", type=str, default=None, nargs="*", - help="base modules for differential learning / 差分学習用のベースモデル", + help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", ) parser.add_argument( - "--base_modules_weight", + "--base_weights_multiplier", type=float, default=None, nargs="*", - help="weights of base modules for differential learning / 差分学習用のベースモデルの比重", + help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) return parser