差分学習機能追加

This commit is contained in:
u-haru
2023-05-27 05:15:02 +09:00
parent b6ba4cac83
commit dd8e17cb37

View File

@@ -148,6 +148,20 @@ def train(args):
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# prepare network
import sys
sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_modules is not None:
for module_path in args.base_modules:
print("merging module: %s"%module_path)
module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu")
print("all modules merged: %s"%", ".join(args.base_modules))
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
@@ -162,13 +176,6 @@ def train(args):
accelerator.wait_for_everyone()
# prepare network
import sys
sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
@@ -770,6 +777,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
parser.add_argument(
"--base_modules",
type=str, default=None, nargs="*",
help="base modules for differential learning / 差分学習用のベースモデル",
)
parser.add_argument(
"--base_modules_weight",
type=float, default=1,
help="weight of base modules for differential learning / 差分学習用のベースモデルの比重",
)
return parser