mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support individual LR for CLIP-L/T5XXL
This commit is contained in:
@@ -466,9 +466,17 @@ class NetworkTrainer:
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
# 後方互換性を確保するよ
|
||||
# make backward compatibility for text_encoder_lr
|
||||
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
|
||||
if support_multiple_lrs:
|
||||
text_encoder_lr = args.text_encoder_lr
|
||||
else:
|
||||
text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
|
||||
try:
|
||||
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
if support_multiple_lrs:
|
||||
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
else:
|
||||
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
if type(results) is tuple:
|
||||
trainable_params = results[0]
|
||||
lr_descriptions = results[1]
|
||||
@@ -476,11 +484,7 @@ class NetworkTrainer:
|
||||
trainable_params = results
|
||||
lr_descriptions = None
|
||||
except TypeError as e:
|
||||
# logger.warning(f"{e}")
|
||||
# accelerator.print(
|
||||
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||
# )
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
|
||||
lr_descriptions = None
|
||||
|
||||
# if len(trainable_params) == 0:
|
||||
@@ -713,7 +717,7 @@ class NetworkTrainer:
|
||||
"ss_training_started_at": training_started_at, # unix timestamp
|
||||
"ss_output_name": args.output_name,
|
||||
"ss_learning_rate": args.learning_rate,
|
||||
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||
"ss_text_encoder_lr": text_encoder_lr,
|
||||
"ss_unet_lr": args.unet_lr,
|
||||
"ss_num_train_images": train_dataset_group.num_train_images,
|
||||
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
||||
@@ -760,8 +764,8 @@ class NetworkTrainer:
|
||||
"ss_loss_type": args.loss_type,
|
||||
"ss_huber_schedule": args.huber_schedule,
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": args.fp8_base,
|
||||
"ss_fp8_base_unet": args.fp8_base_unet,
|
||||
"ss_fp8_base": bool(args.fp8_base),
|
||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
parser.add_argument(
|
||||
"--text_encoder_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp8_base_unet",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user