add backward compatiblity

This commit is contained in:
Kohya S
2023-04-04 08:27:11 +09:00
parent 0fcdda7175
commit 76bac2c1c5

View File

@@ -213,7 +213,13 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 後方互換性を確保するよ
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError:
print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)")
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する