diff --git a/train_network.py b/train_network.py
index 14084b67..8525efd7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -148,7 +148,7 @@ def train(args):
     # モデルに xformers とか memory efficient attention を組み込む
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
 
-    # prepare network
+    # load the base modules used for differential (additional) learning
     import sys
 
     sys.path.append(os.path.dirname(__file__))
@@ -156,11 +156,21 @@ def train(args):
     network_module = importlib.import_module(args.network_module)
 
     if args.base_modules is not None:
-        for module_path in args.base_modules:
-            print("merging module: %s"%module_path)
-            module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True)
-            module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu")
-        print("all modules merged: %s"%", ".join(args.base_modules))
+        # merge every base module; weights pair with modules by position, and a module without a weight uses 1.0
+        for i, module_path in enumerate(args.base_modules):
+            print(f"merging module: {module_path}")
+
+            if args.base_modules_weight is None or len(args.base_modules_weight) <= i:
+                weight = 1.0
+            else:
+                weight = args.base_modules_weight[i]
+
+            module, weights_sd = network_module.create_network_from_weights(
+                weight, module_path, vae, text_encoder, unet, for_inference=True
+            )
+            module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
+
+        print(f"all modules merged: {', '.join(args.base_modules)}")
 
     # 学習を準備する
     if cache_latents:
@@ -176,6 +186,7 @@ def train(args):
 
         accelerator.wait_for_everyone()
 
+    # prepare network
     net_kwargs = {}
     if args.network_args is not None:
         for net_arg in args.network_args:
@@ -779,13 +790,17 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--base_modules",
-        type=str, default=None, nargs="*",
+        type=str,
+        default=None,
+        nargs="*",
         help="base modules for differential learning / 差分学習用のベースモデル",
     )
     parser.add_argument(
         "--base_modules_weight",
-        type=float, default=1,
-        help="weight of base modules for differential learning / 差分学習用のベースモデルの比重",
+        type=float,
+        default=None,
+        nargs="*",
+        help="weights of base modules for differential learning / 差分学習用のベースモデルの比重",
     )
 
     return parser