add --log_config option to enable/disable output training config

This commit is contained in:
Kohya S
2024-05-19 17:21:04 +09:00
parent 47187f7079
commit c68baae480
11 changed files with 42 additions and 16 deletions

View File

@@ -3180,6 +3180,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインするオプション",
)
parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
parser.add_argument(
"--noise_offset",
@@ -3388,7 +3389,15 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)
def filter_sensitive_args(args: argparse.Namespace):
def get_sanitized_config_or_none(args: argparse.Namespace):
# if `--log_config` is enabled, return args for logging. if not, return None.
# when `--log_config is enabled, filter out sensitive values from args
# if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe
if not args.log_config:
return None
sensitive_args = ["wandb_api_key", "huggingface_token"]
sensitive_path_args = [
"pretrained_model_name_or_path",
@@ -3402,9 +3411,9 @@ def filter_sensitive_args(args: argparse.Namespace):
]
filtered_args = {}
for k, v in vars(args).items():
# filter out sensitive values
# filter out sensitive values and convert to string if necessary
if k not in sensitive_args + sensitive_path_args:
#Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
# Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
filtered_args[k] = v
# accelerate does not support lists
@@ -3416,6 +3425,7 @@ def filter_sensitive_args(args: argparse.Namespace):
return filtered_args
# verify command line args for training
def verify_command_line_training_args(args: argparse.Namespace):
# if wandb is enabled, the command line is exposed to the public