implement stratified_lr

This commit is contained in:
u-haru
2023-03-31 00:39:35 +09:00
parent b1dffe8d9a
commit 4dacc52bde
2 changed files with 125 additions and 18 deletions

View File

@@ -191,7 +191,7 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する