mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
@@ -3388,6 +3388,33 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
||||
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
|
||||
def filter_sensitive_args(args: argparse.Namespace):
|
||||
sensitive_args = ["wandb_api_key", "huggingface_token"]
|
||||
sensitive_path_args = [
|
||||
"pretrained_model_name_or_path",
|
||||
"vae",
|
||||
"tokenizer_cache_dir",
|
||||
"train_data_dir",
|
||||
"conditioning_data_dir",
|
||||
"reg_data_dir",
|
||||
"output_dir",
|
||||
"logging_dir",
|
||||
]
|
||||
filtered_args = {}
|
||||
for k, v in vars(args).items():
|
||||
# filter out sensitive values
|
||||
if k not in sensitive_args + sensitive_path_args:
|
||||
#Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
|
||||
if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
|
||||
filtered_args[k] = v
|
||||
# accelerate does not support lists
|
||||
elif isinstance(v, list):
|
||||
filtered_args[k] = f"{v}"
|
||||
# accelerate does not support objects
|
||||
elif isinstance(v, object):
|
||||
filtered_args[k] = f"{v}"
|
||||
|
||||
return filtered_args
|
||||
|
||||
# verify command line args for training
|
||||
def verify_command_line_training_args(args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user