refactor config parse, feature to output config

This commit is contained in:
Kohya S
2023-03-19 10:11:11 +09:00
parent c3f9eb10f1
commit 83e102c691
5 changed files with 75 additions and 77 deletions

View File

@@ -408,24 +408,6 @@ if __name__ == "__main__":
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args) train(args)

View File

@@ -3,6 +3,7 @@
import argparse import argparse
import importlib import importlib
import json import json
import pathlib
import re import re
import shutil import shutil
import time import time
@@ -23,6 +24,7 @@ import random
import hashlib import hashlib
import subprocess import subprocess
from io import BytesIO from io import BytesIO
import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -1889,7 +1891,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類",
) )
parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") parser.add_argument(
"--config_file",
type=str,
default=None,
help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
)
parser.add_argument(
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)
if support_dreambooth: if support_dreambooth:
# DreamBooth training # DreamBooth training
@@ -2016,6 +2026,66 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
) )
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
if not args.config_file:
return args
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if args.output_config:
# check if config file exists
if os.path.exists(config_path):
print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
exit(1)
# convert args to dictionary
args_dict = vars(args)
# remove unnecessary keys
for key in ["config_file", "output_config"]:
if key in args_dict:
del args_dict[key]
# convert Path to str in dictionary
for key, value in args_dict.items():
if isinstance(value, pathlib.Path):
args_dict[key] = str(value)
# convert to toml and output to file
with open(config_path, "w") as f:
toml.dump(args_dict, f)
print(f"Saved config file / 設定ファイルを保存しました: {config_path}")
exit(0)
if not os.path.exists(config_path):
print(f"{config_path} not found.")
exit(1)
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
# combine all sections into one
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
# if value is not dict, save key and value as is
if not isinstance(section_dict, dict):
ignore_nesting_dict[section_name] = section_dict
continue
# if value is dict, save all key and value into one dict
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = os.path.splitext(args.config_file)[0]
print(args.config_file)
return args
# endregion # endregion
# region utils # region utils

View File

@@ -411,24 +411,6 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args) train(args)

View File

@@ -695,24 +695,6 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args) train(args)

View File

@@ -573,24 +573,6 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args) train(args)