From dd8e17cb37bcc9f5e57187a8c359082db40a5c8a Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sat, 27 May 2023 05:15:02 +0900 Subject: [PATCH] =?UTF-8?q?=E5=B7=AE=E5=88=86=E5=AD=A6=E7=BF=92=E6=A9=9F?= =?UTF-8?q?=E8=83=BD=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_network.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index f2fd2009..14084b67 100644 --- a/train_network.py +++ b/train_network.py @@ -148,6 +148,20 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + # prepare network + import sys + + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + if args.base_modules is not None: + for module_path in args.base_modules: + print("merging module: %s"%module_path) + module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu") + print("all modules merged: %s"%", ".join(args.base_modules)) + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -162,13 +176,6 @@ def train(args): accelerator.wait_for_everyone() - # prepare network - import sys - - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: @@ -770,6 +777,16 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) + parser.add_argument( + "--base_modules", + type=str, default=None, nargs="*", + help="base modules for differential learning / 差分学習用のベースモデル", + ) + parser.add_argument( + "--base_modules_weight", + type=float, default=1, + help="weight of base modules for differential learning / 差分学習用のベースモデルの比重", + ) return parser