diff --git a/fine_tune.py b/fine_tune.py index 12557597..20f94cd4 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,6 +5,7 @@ import argparse import gc import math import os +import toml from tqdm import tqdm import torch @@ -362,4 +363,24 @@ if __name__ == '__main__': parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") args = parser.parse_args() + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + train(args) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index cbc5033f..491e4591 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from tqdm import tqdm import library.train_util as train_util - +import os def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" @@ -29,6 +29,9 @@ def main(args): caption_path = image_path.with_suffix(args.caption_extension) caption = caption_path.read_text(encoding='utf-8').strip() + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 4285feb0..8823a9c8 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from tqdm import tqdm import library.train_util as train_util - +import os def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" @@ -29,6 +29,9 @@ def main(args): tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} diff --git a/library/train_util.py b/library/train_util.py index 718fe36d..248156a3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1598,6 +1598,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類') + + parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter") if support_dreambooth: # DreamBooth training diff --git a/train_db.py b/train_db.py index a3021177..5fd3c65b 100644 --- a/train_db.py +++ b/train_db.py @@ -7,6 +7,7 @@ import argparse import itertools import math import os +import toml from tqdm import tqdm import torch @@ -361,4 +362,24 @@ if __name__ == '__main__': help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") args = parser.parse_args() - train(args) + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + + train(args) \ No newline at end of file diff --git a/train_network.py b/train_network.py index 5aa8af48..454bd254 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import os import random import time import json +import toml from tqdm import tqdm import torch @@ -656,4 +657,24 @@ if __name__ == '__main__': help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() - train(args) + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + + train(args) \ No newline at end of file diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 34b7f092..7cfaedfe 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -3,6 +3,7 @@ import argparse import gc import math import os +import toml from tqdm import tqdm import torch @@ -523,4 +524,24 @@ if __name__ == '__main__': help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する") args = parser.parse_args() + + if args.config_file: + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + if os.path.exists(config_path): + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = args.config_file.split(".")[0] + print(args.config_file) + else: + print(f"{config_path} not found.") + train(args)