feat: added function to load training config with .toml

This commit is contained in:
Linaqruf
2023-03-12 11:52:37 +07:00
parent 7c1cf7f4ea
commit 44d4cfb453
7 changed files with 96 additions and 4 deletions

View File

@@ -7,6 +7,7 @@ import argparse
import itertools
import math
import os
import toml
from tqdm import tqdm
import torch
@@ -361,4 +362,24 @@ if __name__ == '__main__':
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない")
args = parser.parse_args()
train(args)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args)