feat: add arguments to set \--wandb_api_key\ before training

This commit is contained in:
Linaqruf
2023-04-22 16:14:14 +07:00
parent f256660780
commit ae3965a2a7

View File

@@ -2079,6 +2079,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument( parser.add_argument(
"--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名" "--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名"
) )
parser.add_argument(
"--wandb_api_key",
type=str,
default=None,
help="specify WandB API key to log in before starting training (optional)."
)
parser.add_argument( parser.add_argument(
"--noise_offset", "--noise_offset",
type=float, type=float,
@@ -2299,7 +2305,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
args_dict = vars(args) args_dict = vars(args)
# remove unnecessary keys # remove unnecessary keys
for key in ["config_file", "output_config"]: for key in ["config_file", "output_config", "wandb_api_key"]:
if key in args_dict: if key in args_dict:
del args_dict[key] del args_dict[key]
@@ -2761,6 +2767,8 @@ def prepare_accelerator(args: argparse.Namespace):
if logging_dir is not None: if logging_dir is not None:
os.makedirs(logging_dir, exist_ok=True) os.makedirs(logging_dir, exist_ok=True)
os.environ["WANDB_DIR"] = logging_dir os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,