From ae3965a2a7b46586f7da50c25d4b1cc772ab1ad7 Mon Sep 17 00:00:00 2001 From: Linaqruf Date: Sat, 22 Apr 2023 16:14:14 +0700 Subject: [PATCH] feat: add arguments to set \--wandb_api_key\ before training --- library/train_util.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d43e0075..8ec0ccbf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2079,6 +2079,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名" ) + parser.add_argument( + "--wandb_api_key", + type=str, + default=None, + help="specify WandB API key to log in before starting training (optional)." + ) parser.add_argument( "--noise_offset", type=float, @@ -2299,7 +2305,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar args_dict = vars(args) # remove unnecessary keys - for key in ["config_file", "output_config"]: + for key in ["config_file", "output_config", "wandb_api_key"]: if key in args_dict: del args_dict[key] @@ -2761,6 +2767,8 @@ def prepare_accelerator(args: argparse.Namespace): if logging_dir is not None: os.makedirs(logging_dir, exist_ok=True) os.environ["WANDB_DIR"] = logging_dir + if args.wandb_api_key is not None: + wandb.login(key=args.wandb_api_key) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps,