support individual LR for CLIP-L/T5XXL

This commit is contained in:
Kohya S
2024-09-10 20:32:09 +09:00
parent d29af146b8
commit d10ff62a78
3 changed files with 49 additions and 58 deletions

View File

@@ -466,9 +466,17 @@ class NetworkTrainer:
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
# 後方互換性を確保するよ
# maintain backward compatibility for text_encoder_lr (networks without multiple text encoder LR support receive only the first value)
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
if support_multiple_lrs:
text_encoder_lr = args.text_encoder_lr
else:
text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
try:
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
if support_multiple_lrs:
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
else:
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
if type(results) is tuple:
trainable_params = results[0]
lr_descriptions = results[1]
@@ -476,11 +484,7 @@ class NetworkTrainer:
trainable_params = results
lr_descriptions = None
except TypeError as e:
# logger.warning(f"{e}")
# accelerator.print(
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
# )
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
lr_descriptions = None
# if len(trainable_params) == 0:
@@ -713,7 +717,7 @@ class NetworkTrainer:
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_text_encoder_lr": text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset_group.num_train_images,
"ss_num_reg_images": train_dataset_group.num_reg_images,
@@ -760,8 +764,8 @@ class NetworkTrainer:
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
"ss_fp8_base": args.fp8_base,
"ss_fp8_base_unet": args.fp8_base_unet,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument(
"--text_encoder_lr",
type=float,
default=None,
nargs="*",
help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
)
parser.add_argument(
"--fp8_base_unet",
action="store_true",