diff --git a/train_network.py b/train_network.py index 0fd8cc0f..9956a905 100644 --- a/train_network.py +++ b/train_network.py @@ -213,7 +213,13 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + # 後方互換性を確保するよ + try: + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + except TypeError: + print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)") + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する