mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge pull request #241 from Linaqruf/main
Load training arguments from .yaml, and other small changes
This commit is contained in:
21
fine_tune.py
21
fine_tune.py
@@ -5,6 +5,7 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -362,4 +363,24 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import library.train_util as train_util
|
||||
|
||||
import os
|
||||
|
||||
def main(args):
|
||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||
@@ -29,6 +29,9 @@ def main(args):
|
||||
caption_path = image_path.with_suffix(args.caption_extension)
|
||||
caption = caption_path.read_text(encoding='utf-8').strip()
|
||||
|
||||
if not os.path.exists(caption_path):
|
||||
caption_path = os.path.join(image_path, args.caption_extension)
|
||||
|
||||
image_key = str(image_path) if args.full_path else image_path.stem
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import library.train_util as train_util
|
||||
|
||||
import os
|
||||
|
||||
def main(args):
|
||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||
@@ -29,6 +29,9 @@ def main(args):
|
||||
tags_path = image_path.with_suffix(args.caption_extension)
|
||||
tags = tags_path.read_text(encoding='utf-8').strip()
|
||||
|
||||
if not os.path.exists(tags_path):
|
||||
tags_path = os.path.join(image_path, args.caption_extension)
|
||||
|
||||
image_key = str(image_path) if args.full_path else image_path.stem
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
@@ -1598,6 +1598,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
'dpmsolver++', 'dpmsingle',
|
||||
'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'],
|
||||
help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類')
|
||||
|
||||
parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
|
||||
23
train_db.py
23
train_db.py
@@ -7,6 +7,7 @@ import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -361,4 +362,24 @@ if __name__ == '__main__':
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
|
||||
train(args)
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -656,4 +657,24 @@ if __name__ == '__main__':
|
||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
|
||||
train(args)
|
||||
@@ -3,6 +3,7 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -523,4 +524,24 @@ if __name__ == '__main__':
|
||||
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config_file:
|
||||
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
||||
if os.path.exists(config_path):
|
||||
print(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
ignore_nesting_dict = {}
|
||||
for section_name, section_dict in config_dict.items():
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = args.config_file.split(".")[0]
|
||||
print(args.config_file)
|
||||
else:
|
||||
print(f"{config_path} not found.")
|
||||
|
||||
train(args)
|
||||
|
||||
Reference in New Issue
Block a user