From 9b13444b9ce67b2e6201a125196597ecffefdd96 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 23 Feb 2023 21:35:47 +0900
Subject: [PATCH] raise error if options conflict

---
 library/train_util.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 37642dd5..a02207b4 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1372,8 +1372,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                        help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
+    parser.add_argument("--optimizer_type", type=str, default="",
+                        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
 
     # backward compatibility
     parser.add_argument("--use_8bit_adam", action="store_true",
@@ -1532,11 +1532,16 @@ def get_optimizer(args, trainable_params):
     optimizer_type = args.optimizer_type
     if args.use_8bit_adam:
-        print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
+        assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
+        assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
         optimizer_type = "AdamW8bit"
+
     elif args.use_lion_optimizer:
-        print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
+        assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
         optimizer_type = "Lion"
+
+    if optimizer_type is None or optimizer_type == "":
+        optimizer_type = "AdamW"
 
     optimizer_type = optimizer_type.lower()
 
     # 引数を分解する:boolとfloat、tupleのみ対応
@@ -1557,7 +1562,7 @@
                 value = tuple(value)
 
             optimizer_kwargs[key] = value
-    print("optkwargs:", optimizer_kwargs)
+    # print("optkwargs:", optimizer_kwargs)
 
     lr = args.learning_rate
 
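Note: with this change, --optimizer_type defaults to the empty string so that get_optimizer can tell whether the user set it explicitly; the AdamW default is now applied inside get_optimizer, after the legacy flags have been checked. Below is a minimal, self-contained sketch of the resolution logic the patch introduces, using plain argparse.Namespace objects in place of parsed command-line arguments. resolve_optimizer_type is a hypothetical helper name written for illustration; it is not part of the patch.

# Minimal sketch, not part of the patch: reproduces the optimizer-type
# resolution that the new get_optimizer code implements.
import argparse


def resolve_optimizer_type(args: argparse.Namespace) -> str:
    optimizer_type = args.optimizer_type
    if args.use_8bit_adam:
        # Conflicting options now fail fast instead of being silently ignored.
        assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified"
        assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified"
        optimizer_type = "AdamW8bit"
    elif args.use_lion_optimizer:
        assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified"
        optimizer_type = "Lion"

    # The AdamW default moved out of argparse: an unset --optimizer_type
    # is only filled in here, after the legacy flags have been checked.
    if optimizer_type is None or optimizer_type == "":
        optimizer_type = "AdamW"
    return optimizer_type


# A legacy flag on its own still works:
ok = argparse.Namespace(use_8bit_adam=True, use_lion_optimizer=False, optimizer_type="")
print(resolve_optimizer_type(ok))  # -> AdamW8bit

# A conflict now raises instead of silently ignoring optimizer_type:
bad = argparse.Namespace(use_8bit_adam=True, use_lion_optimizer=False, optimizer_type="Lion")
try:
    resolve_optimizer_type(bad)
except AssertionError as e:
    print("AssertionError:", e)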