Merge pull request #241 from Linaqruf/main

Load training arguments from .yaml, and other small changes
This commit is contained in:
Kohya S
2023-03-18 18:50:42 +09:00
committed by GitHub
7 changed files with 96 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ import argparse
import gc
import math
import os
import toml
from tqdm import tqdm
import torch
@@ -362,4 +363,24 @@ if __name__ == '__main__':
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
args = parser.parse_args()
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args)

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@@ -29,6 +29,9 @@ def main(args):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@@ -29,6 +29,9 @@ def main(args):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}

View File

@@ -1598,6 +1598,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
'dpmsolver++', 'dpmsingle',
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter")
if support_dreambooth:
# DreamBooth training

View File

@@ -7,6 +7,7 @@ import argparse
import itertools
import math
import os
import toml
from tqdm import tqdm
import torch
@@ -361,4 +362,24 @@ if __name__ == '__main__':
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない")
args = parser.parse_args()
train(args)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args)

View File

@@ -7,6 +7,7 @@ import os
import random
import time
import json
import toml
from tqdm import tqdm
import torch
@@ -656,4 +657,24 @@ if __name__ == '__main__':
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
args = parser.parse_args()
train(args)
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args)

View File

@@ -3,6 +3,7 @@ import argparse
import gc
import math
import os
import toml
from tqdm import tqdm
import torch
@@ -523,4 +524,24 @@ if __name__ == '__main__':
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する")
args = parser.parse_args()
if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)
ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")
train(args)