Add scaling alpha for LoRA

This commit is contained in:
Kohya S
2023-01-21 20:37:34 +09:00
parent 22ee0ac467
commit b4636d4185
3 changed files with 59 additions and 26 deletions

View File

@@ -107,7 +107,8 @@ def train(args):
key, value = net_arg.split('=')
net_kwargs[key] = value
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
@@ -243,7 +244,8 @@ def train(args):
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
@@ -445,6 +447,8 @@ if __name__ == '__main__':
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
parser.add_argument("--network_dim", type=int, default=None,
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_alpha", type=float, default=1,
help='alpha for LoRA weight scaling, 0 for no scaling (same as old version) / LoRaの重み調整のalpha値、0で調整なし旧バージョンと同じ')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")