diff --git a/train_network.py b/train_network.py index f2fd2009..14084b67 100644 --- a/train_network.py +++ b/train_network.py @@ -148,6 +148,20 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + # prepare network + import sys + + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + if args.base_modules is not None: + for module_path in args.base_modules: + print("merging module: %s"%module_path) + module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu") + print("all modules merged: %s"%", ".join(args.base_modules)) + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -162,13 +176,6 @@ def train(args): accelerator.wait_for_everyone() - # prepare network - import sys - - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: @@ -770,6 +777,16 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) + parser.add_argument( + "--base_modules", + type=str, default=None, nargs="*", + help="base modules for differential learning / 差分学習用のベースモデル", + ) + parser.add_argument( + "--base_modules_weight", + type=float, default=1, + help="weight of base modules for differential learning / 差分学習用のベースモデルの比重", + ) return parser