passing filtered hyperparameters to accelerate

This commit is contained in:
Maatra
2024-04-20 14:11:43 +01:00
parent 71e2c91330
commit 2c9db5d9f2
10 changed files with 23 additions and 9 deletions

View File

@@ -3378,6 +3378,20 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)
def filter_sensitive_args(args: argparse.Namespace):
sensitive_args = ["wandb_api_key", "huggingface_token"]
sensitive_path_args = [
"pretrained_model_name_or_path",
"vae",
"tokenizer_cache_dir",
"train_data_dir",
"conditioning_data_dir",
"reg_data_dir",
"output_dir",
"logging_dir",
]
filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args}
return filtered_args
# verify command line args for training
def verify_command_line_training_args(args: argparse.Namespace):