mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: added function to load training config with .toml
This commit is contained in:
23
train_db.py
23
train_db.py
@@ -7,6 +7,7 @@ import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -361,4 +362,24 @@ if __name__ == '__main__':
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
|
||||
train(args)
|
||||
Reference in New Issue
Block a user