mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
verify command line args if wandb is enabled
This commit is contained in:
@@ -520,6 +520,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -3358,6 +3358,60 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# verify command line args for training
|
||||||
|
def verify_command_line_training_args(args: argparse.Namespace):
|
||||||
|
# if wandb is enabled, the command line is exposed to the public
|
||||||
|
# check whether sensitive options are included in the command line arguments
|
||||||
|
# if so, warn or inform the user to move them to the configuration file
|
||||||
|
# wandbが有効な場合、コマンドラインが公開される
|
||||||
|
# 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、
|
||||||
|
# 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する
|
||||||
|
|
||||||
|
wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb"
|
||||||
|
if not wandb_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
sensitive_args = ["wandb_api_key", "huggingface_token"]
|
||||||
|
sensitive_path_args = [
|
||||||
|
"pretrained_model_name_or_path",
|
||||||
|
"vae",
|
||||||
|
"tokenizer_cache_dir",
|
||||||
|
"train_data_dir",
|
||||||
|
"conditioning_data_dir",
|
||||||
|
"reg_data_dir",
|
||||||
|
"output_dir",
|
||||||
|
"logging_dir",
|
||||||
|
]
|
||||||
|
|
||||||
|
for arg in sensitive_args:
|
||||||
|
if getattr(args, arg, None) is not None:
|
||||||
|
logger.warning(
|
||||||
|
f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
|
||||||
|
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# if path is absolute, it may include sensitive information
|
||||||
|
for arg in sensitive_path_args:
|
||||||
|
if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)):
|
||||||
|
logger.info(
|
||||||
|
f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path."
|
||||||
|
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。"
|
||||||
|
)
|
||||||
|
|
||||||
|
if getattr(args, "config_file", None) is not None:
|
||||||
|
logger.info(
|
||||||
|
f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path."
|
||||||
|
+ f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# other sensitive options
|
||||||
|
if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public":
|
||||||
|
logger.info(
|
||||||
|
f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
|
||||||
|
+ f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_training_args(args: argparse.Namespace):
|
def verify_training_args(args: argparse.Namespace):
|
||||||
r"""
|
r"""
|
||||||
Verify training arguments. Also reflect highvram option to global variable
|
Verify training arguments. Also reflect highvram option to global variable
|
||||||
|
|||||||
@@ -812,6 +812,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -612,6 +612,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -580,6 +580,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
trainer = SdxlNetworkTrainer()
|
trainer = SdxlNetworkTrainer()
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
trainer = SdxlTextualInversionTrainer()
|
trainer = SdxlTextualInversionTrainer()
|
||||||
|
|||||||
@@ -617,6 +617,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -523,6 +523,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -1101,6 +1101,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
trainer = NetworkTrainer()
|
trainer = NetworkTrainer()
|
||||||
|
|||||||
@@ -806,6 +806,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
trainer = TextualInversionTrainer()
|
trainer = TextualInversionTrainer()
|
||||||
|
|||||||
@@ -714,6 +714,7 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user