Merge pull request #1233 from rockerBOO/lora-plus

Add LoRA+ support
Kohya S
2024-04-29 18:05:12 +09:00
committed by GitHub
5 changed files with 220 additions and 74 deletions
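
For context: LoRA+ (Hayou et al., 2024) speeds up LoRA fine-tuning by training the two low-rank factors at different learning rates, with the lora_up (B) weights running at a multiple of the rate used for the lora_down (A) weights. This PR exposes that multiple through the new loraplus_lr_ratio, loraplus_unet_lr_ratio, and loraplus_text_encoder_lr_ratio arguments. A minimal sketch of the idea, illustrative only and independent of sd-scripts' actual prepare_optimizer_params implementation:

import torch

# Minimal LoRA+ sketch (an illustration, not the sd-scripts implementation):
# lora_down (A) keeps the base learning rate, lora_up (B) gets base_lr * ratio.
def make_loraplus_param_groups(named_params, base_lr, loraplus_ratio):
    base_group = {"params": [], "lr": base_lr}
    plus_group = {"params": [], "lr": base_lr * loraplus_ratio}
    for name, param in named_params:
        if not param.requires_grad:
            continue
        target = plus_group if "lora_up" in name else base_group
        target["params"].append(param)
    return [base_group, plus_group]

# usage with any torch optimizer, given a network with lora_up/lora_down modules:
# optimizer = torch.optim.AdamW(make_loraplus_param_groups(network.named_parameters(), 1e-4, 16))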


@@ -64,14 +64,61 @@ class NetworkTrainer:
         lrs = lr_scheduler.get_last_lr()
 
-        if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
-            if args.network_train_unet_only:
-                logs["lr/unet"] = float(lrs[0])
-            elif args.network_train_text_encoder_only:
+        if len(lrs) > 4:
+            idx = 0
+            if not args.network_train_unet_only:
                 logs["lr/textencoder"] = float(lrs[0])
+                idx = 1
+
+            for i in range(idx, len(lrs)):
+                lora_plus = ""
+                group_id = i
+
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    lora_plus = '_lora+' if i % 2 == 1 else ''
+                    group_id = int((i / 2) + (i % 2 + 0.5))
+
+                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
+                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+                    )
+        else:
+            if args.network_train_text_encoder_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/textencoder"] = float(lrs[0])
+            elif args.network_train_unet_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    logs["lr/unet"] = float(lrs[0])
+                    logs["lr/unet_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/unet"] = float(lrs[0])
             else:
-                logs["lr/textencoder"] = float(lrs[0])
-                logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
+                if len(lrs) == 2:
+                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/textencoder_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
+                        logs["lr/unet"] = float(lrs[0])
+                        logs["lr/unet_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
+                        logs["lr/all"] = float(lrs[0])
+                        logs["lr/all_lora+"] = float(lrs[1])
+                    else:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/unet"] = float(lrs[-1])
+                elif len(lrs) == 4:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                    logs["lr/unet"] = float(lrs[2])
+                    logs["lr/unet_lora+"] = float(lrs[3])
+                else:
+                    logs["lr/all"] = float(lrs[0])
 
             if (
                 args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -79,18 +126,6 @@ class NetworkTrainer:
                 logs["lr/d*lr"] = (
                     lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
                 )
-        else:
-            idx = 0
-            if not args.network_train_unet_only:
-                logs["lr/textencoder"] = float(lrs[0])
-                idx = 1
-
-            for i in range(idx, len(lrs)):
-                logs[f"lr/group{i}"] = float(lrs[i])
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{i}"] = (
-                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                    )
 
         return logs
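
With LoRA+ enabled, each logical unit (text encoder, U-Net, or per-block group) contributes two param groups to the optimizer, so the scheduler returns learning rates in pairs and the logging above splits each pair into a base key and a _lora+ key. For the common non-block case with both ratios set, the four scheduler LRs map to keys as in this toy reproduction of the len(lrs) == 4 branch (assuming param groups are ordered text encoder first):

# toy check of the len(lrs) == 4 branch: base lr 1e-4 with a LoRA+ ratio of 16
base_lr, ratio = 1e-4, 16
lrs = [base_lr, base_lr * ratio, base_lr, base_lr * ratio]

logs = {}
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
logs["lr/unet"] = float(lrs[2])
logs["lr/unet_lora+"] = float(lrs[3])
print(logs)
# {'lr/textencoder': 0.0001, 'lr/textencoder_lora+': 0.0016,
#  'lr/unet': 0.0001, 'lr/unet_lora+': 0.0016}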
@@ -338,7 +373,7 @@ class NetworkTrainer:
         # ensure backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
@@ -347,6 +382,11 @@ class NetworkTrainer:
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
+        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+            assert (
+                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
+            ), "LoRA+ and Prodigy/DAdaptation is not supported"
+
         # prepare the dataloader
         # note: persistent_workers cannot be used when the DataLoader's number of worker processes is 0
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
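
The new assert refuses to combine LoRA+ with Prodigy or DAdaptation, presumably because those optimizers adapt the effective step size on their own (the d in the d*lr logs above), so a fixed ratio between the lora_up and lora_down learning rates would not be preserved. The dataloader comment points at a real PyTorch constraint: persistent_workers=True raises a ValueError when num_workers == 0, so the flag has to be gated on the worker count. A generic sketch of that interaction (not the exact sd-scripts call):

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

# generic sketch of the num_workers / persistent_workers interaction:
# PyTorch rejects persistent_workers=True when num_workers == 0
max_data_loader_n_workers = 8  # stand-in for args.max_data_loader_n_workers
n_workers = min(max_data_loader_n_workers, os.cpu_count())

dataset = TensorDataset(torch.zeros(16, 3))  # toy dataset
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=n_workers,
    persistent_workers=n_workers > 0,  # only valid with at least one worker
)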