move loraplus args from args to network_args, simplify log lr desc

This commit is contained in:
Kohya S
2024-04-29 20:04:25 +09:00
parent 834445a1d6
commit 969f82ab47
3 changed files with 84 additions and 91 deletions

View File

@@ -53,7 +53,15 @@ class NetworkTrainer:
# TODO 他のスクリプトと共通化する
def generate_step_logs(
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
self,
args: argparse.Namespace,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -63,68 +71,25 @@ class NetworkTrainer:
logs["max_norm/max_key_norm"] = maximum_norm
lrs = lr_scheduler.get_last_lr()
if len(lrs) > 4:
idx = 0
if not args.network_train_unet_only:
logs["lr/textencoder"] = float(lrs[0])
idx = 1
for i in range(idx, len(lrs)):
lora_plus = ""
group_id = i
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
lora_plus = '_lora+' if i % 2 == 1 else ''
group_id = int((i / 2) + (i % 2 + 0.5))
logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
else:
if args.network_train_text_encoder_only:
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
else:
logs["lr/textencoder"] = float(lrs[0])
elif args.network_train_unet_only:
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
logs["lr/unet"] = float(lrs[0])
logs["lr/unet_lora+"] = float(lrs[1])
else:
logs["lr/unet"] = float(lrs[0])
for i, lr in enumerate(lrs):
if lr_descriptions is not None:
lr_desc = lr_descriptions[i]
else:
if len(lrs) == 2:
if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
logs["lr/unet"] = float(lrs[0])
logs["lr/unet_lora+"] = float(lrs[1])
elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
logs["lr/all"] = float(lrs[0])
logs["lr/all_lora+"] = float(lrs[1])
else:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1])
elif len(lrs) == 4:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/textencoder_lora+"] = float(lrs[1])
logs["lr/unet"] = float(lrs[2])
logs["lr/unet_lora+"] = float(lrs[3])
idx = i - (0 if args.network_train_unet_only else -1)
if idx == -1:
lr_desc = "textencoder"
else:
logs["lr/all"] = float(lrs[0])
if len(lrs) > 2:
lr_desc = f"group{idx}"
else:
lr_desc = "unet"
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
logs[f"lr/{lr_desc}"] = lr
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
# tracking d*lr value
logs[f"lr/d*lr/{lr_desc}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
return logs
@@ -358,6 +323,7 @@ class NetworkTrainer:
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
@@ -373,20 +339,23 @@ class NetworkTrainer:
# 後方互換性を確保するよ
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
if type(results) is tuple:
trainable_params = results[0]
lr_descriptions = results[1]
else:
trainable_params = results
lr_descriptions = None
except TypeError:
accelerator.print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
)
# accelerator.print(
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
# )
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
lr_descriptions = None
print(lr_descriptions)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
assert (
(optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
), "LoRA+ and Prodigy/DAdaptation is not supported"
# dataloaderを準備する
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -992,7 +961,9 @@ class NetworkTrainer:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
@@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
return parser