mirror of https://github.com/kohya-ss/sd-scripts.git

Commit: improve
@@ -41,12 +41,17 @@ from .train_util import (
     DatasetGroup,
 )
 from .utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
+
 def add_config_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
+    parser.add_argument(
+        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
+    )
 
 
 # TODO: inherit Params class in Subset, Dataset
@@ -60,6 +65,8 @@ class BaseSubsetParams:
     caption_separator: str = (",",)
     keep_tokens: int = 0
     keep_tokens_separator: str = (None,)
+    secondary_separator: Optional[str] = None
+    enable_wildcard: bool = False
     color_aug: bool = False
     flip_aug: bool = False
     face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -181,6 +188,8 @@ class ConfigSanitizer:
         "shuffle_caption": bool,
         "keep_tokens": int,
         "keep_tokens_separator": str,
+        "secondary_separator": str,
+        "enable_wildcard": bool,
         "token_warmup_min": int,
         "token_warmup_step": Any(float, int),
         "caption_prefix": str,
@@ -247,9 +256,10 @@ class ConfigSanitizer:
     }
 
     def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
-        assert (
-            support_dreambooth or support_finetuning or support_controlnet
-        ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
+        assert support_dreambooth or support_finetuning or support_controlnet, (
+            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
+        )
 
         self.db_subset_schema = self.__merge_dict(
             self.SUBSET_ASCENDABLE_SCHEMA,
@@ -361,7 +371,9 @@ class ConfigSanitizer:
             return self.argparse_config_validator(argparse_namespace)
         except MultipleInvalid:
             # XXX: this should be a bug
-            logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
+            logger.error(
+                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
+            )
             raise
 
     # NOTE: value would be overwritten by latter dict if there is already the same key
@@ -447,7 +459,6 @@ class BlueprintGenerator:
 
         return default_value
 
 
 def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
-
     datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
@@ -467,7 +478,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
         datasets.append(dataset)
 
-    val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
+    val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
 
     for dataset_blueprint in dataset_group_blueprint.datasets:
         if dataset_blueprint.params.validation_split <= 0.0:
             continue
@@ -485,75 +496,174 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
         dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
         val_datasets.append(dataset)
 
-    def print_info(_datasets):
-        info = ""
-        for i, dataset in enumerate(_datasets):
-            is_dreambooth = isinstance(dataset, DreamBoothDataset)
-            is_controlnet = isinstance(dataset, ControlNetDataset)
-            info += dedent(f"""\
-                [Dataset {i}]
-                  batch_size: {dataset.batch_size}
-                  resolution: {(dataset.width, dataset.height)}
-                  enable_bucket: {dataset.enable_bucket}
-                """)
+    # print info
+    info = ""
+    for i, dataset in enumerate(datasets):
+        is_dreambooth = isinstance(dataset, DreamBoothDataset)
+        is_controlnet = isinstance(dataset, ControlNetDataset)
+        info += dedent(
+            f"""\
+            [Dataset {i}]
+              batch_size: {dataset.batch_size}
+              resolution: {(dataset.width, dataset.height)}
+              enable_bucket: {dataset.enable_bucket}
+              network_multiplier: {dataset.network_multiplier}
+            """
+        )
+
         if dataset.enable_bucket:
-            info += indent(dedent(f"""\
-                min_bucket_reso: {dataset.min_bucket_reso}
-                max_bucket_reso: {dataset.max_bucket_reso}
-                bucket_reso_steps: {dataset.bucket_reso_steps}
-                bucket_no_upscale: {dataset.bucket_no_upscale}
-            \n"""), "  ")
+            info += indent(
+                dedent(
+                    f"""\
+                    min_bucket_reso: {dataset.min_bucket_reso}
+                    max_bucket_reso: {dataset.max_bucket_reso}
+                    bucket_reso_steps: {dataset.bucket_reso_steps}
+                    bucket_no_upscale: {dataset.bucket_no_upscale}
+                \n"""
+                ),
+                "  ",
+            )
         else:
             info += "\n"
 
         for j, subset in enumerate(dataset.subsets):
-            info += indent(dedent(f"""\
-                [Subset {j} of Dataset {i}]
-                image_dir: "{subset.image_dir}"
-                image_count: {subset.img_count}
-                num_repeats: {subset.num_repeats}
-                shuffle_caption: {subset.shuffle_caption}
-                keep_tokens: {subset.keep_tokens}
-                caption_dropout_rate: {subset.caption_dropout_rate}
-                caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
-                caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
-                caption_prefix: {subset.caption_prefix}
-                caption_suffix: {subset.caption_suffix}
-                color_aug: {subset.color_aug}
-                flip_aug: {subset.flip_aug}
-                face_crop_aug_range: {subset.face_crop_aug_range}
-                random_crop: {subset.random_crop}
-                token_warmup_min: {subset.token_warmup_min},
-                token_warmup_step: {subset.token_warmup_step},
-            """), "  ")
+            info += indent(
+                dedent(
+                    f"""\
+                    [Subset {j} of Dataset {i}]
+                    image_dir: "{subset.image_dir}"
+                    image_count: {subset.img_count}
+                    num_repeats: {subset.num_repeats}
+                    shuffle_caption: {subset.shuffle_caption}
+                    keep_tokens: {subset.keep_tokens}
+                    keep_tokens_separator: {subset.keep_tokens_separator}
+                    caption_dropout_rate: {subset.caption_dropout_rate}
+                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+                    caption_prefix: {subset.caption_prefix}
+                    caption_suffix: {subset.caption_suffix}
+                    color_aug: {subset.color_aug}
+                    flip_aug: {subset.flip_aug}
+                    face_crop_aug_range: {subset.face_crop_aug_range}
+                    random_crop: {subset.random_crop}
+                    token_warmup_min: {subset.token_warmup_min},
+                    token_warmup_step: {subset.token_warmup_step},
+                """
+                ),
+                "  ",
+            )
 
-            if is_dreambooth:
-                info += indent(dedent(f"""\
-                    is_reg: {subset.is_reg}
-                    class_tokens: {subset.class_tokens}
-                    caption_extension: {subset.caption_extension}
-                \n"""), "    ")
-            elif not is_controlnet:
-                info += indent(dedent(f"""\
-                    metadata_file: {subset.metadata_file}
-                \n"""), "    ")
+            if is_dreambooth:
+                info += indent(
+                    dedent(
+                        f"""\
+                        is_reg: {subset.is_reg}
+                        class_tokens: {subset.class_tokens}
+                        caption_extension: {subset.caption_extension}
+                    \n"""
+                    ),
+                    "    ",
+                )
+            elif not is_controlnet:
+                info += indent(
+                    dedent(
+                        f"""\
+                        metadata_file: {subset.metadata_file}
+                    \n"""
+                    ),
+                    "    ",
+                )
 
-        print(info)
+    logger.info(f'{info}')
 
-    print_info(datasets)
+    # print validation info
+    info = ""
+    for i, dataset in enumerate(val_datasets):
+        is_dreambooth = isinstance(dataset, DreamBoothDataset)
+        is_controlnet = isinstance(dataset, ControlNetDataset)
+        info += dedent(
+            f"""\
+            [Validation Dataset {i}]
+              batch_size: {dataset.batch_size}
+              resolution: {(dataset.width, dataset.height)}
+              enable_bucket: {dataset.enable_bucket}
+              network_multiplier: {dataset.network_multiplier}
+            """
+        )
+
+        if dataset.enable_bucket:
+            info += indent(
+                dedent(
+                    f"""\
+                    min_bucket_reso: {dataset.min_bucket_reso}
+                    max_bucket_reso: {dataset.max_bucket_reso}
+                    bucket_reso_steps: {dataset.bucket_reso_steps}
+                    bucket_no_upscale: {dataset.bucket_no_upscale}
+                \n"""
+                ),
+                "  ",
+            )
+        else:
+            info += "\n"
+
+        for j, subset in enumerate(dataset.subsets):
+            info += indent(
+                dedent(
+                    f"""\
+                    [Subset {j} of Dataset {i}]
+                    image_dir: "{subset.image_dir}"
+                    image_count: {subset.img_count}
+                    num_repeats: {subset.num_repeats}
+                    shuffle_caption: {subset.shuffle_caption}
+                    keep_tokens: {subset.keep_tokens}
+                    keep_tokens_separator: {subset.keep_tokens_separator}
+                    caption_dropout_rate: {subset.caption_dropout_rate}
+                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+                    caption_prefix: {subset.caption_prefix}
+                    caption_suffix: {subset.caption_suffix}
+                    color_aug: {subset.color_aug}
+                    flip_aug: {subset.flip_aug}
+                    face_crop_aug_range: {subset.face_crop_aug_range}
+                    random_crop: {subset.random_crop}
+                    token_warmup_min: {subset.token_warmup_min},
+                    token_warmup_step: {subset.token_warmup_step},
+                """
+                ),
+                "  ",
+            )
+
+            if is_dreambooth:
+                info += indent(
+                    dedent(
+                        f"""\
+                        is_reg: {subset.is_reg}
+                        class_tokens: {subset.class_tokens}
+                        caption_extension: {subset.caption_extension}
+                    \n"""
+                    ),
+                    "    ",
+                )
+            elif not is_controlnet:
+                info += indent(
+                    dedent(
+                        f"""\
+                        metadata_file: {subset.metadata_file}
+                    \n"""
+                    ),
+                    "    ",
+                )
+
+    logger.info(f'{info}')
 
-    if len(val_datasets) > 0:
-        print("Validation dataset")
-        print_info(val_datasets)
-
     # make buckets first because it determines the length of dataset
     # and set the same seed for all datasets
     seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
     for i, dataset in enumerate(datasets):
-        print(f"[Dataset {i}]")
+        logger.info(f"[Dataset {i}]")
         dataset.make_buckets()
         dataset.set_seed(seed)
 
     for i, dataset in enumerate(val_datasets):
         print(f"[Validation Dataset {i}]")
         dataset.make_buckets()
@@ -562,8 +672,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
     return (
         DatasetGroup(datasets),
         DatasetGroup(val_datasets) if val_datasets else None
-        )
+    )
 
 def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
     def extract_dreambooth_params(name: str) -> Tuple[int, str]:
         tokens = name.split("_")
@@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict:
             with open(file, "r") as f:
                 config = json.load(f)
         except Exception:
-            logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     elif file.name.lower().endswith(".toml"):
         try:
             config = toml.load(file)
         except Exception:
-            logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     else:
         raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -675,13 +789,13 @@ if __name__ == "__main__":
    train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
 
     logger.info("[argparse_namespace]")
-    logger.info(f'{vars(argparse_namespace)}')
+    logger.info(f"{vars(argparse_namespace)}")
 
     user_config = load_user_config(config_args.dataset_config)
 
     logger.info("")
     logger.info("[user_config]")
-    logger.info(f'{user_config}')
+    logger.info(f"{user_config}")
 
     sanitizer = ConfigSanitizer(
         config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -690,10 +804,10 @@ if __name__ == "__main__":
 
     logger.info("")
     logger.info("[sanitized_user_config]")
-    logger.info(f'{sanitized_user_config}')
+    logger.info(f"{sanitized_user_config}")
 
     blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
 
     logger.info("")
     logger.info("[blueprint]")
-    logger.info(f'{blueprint}')
+    logger.info(f"{blueprint}")
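For orientation, a minimal sketch (not part of the commit) of how a caller would consume the two-value return that this commit gives generate_dataset_group_by_blueprint. The config filename and the hard-coded support_* flags are illustrative assumptions, and the empty Namespace stands in for the fully parsed training arguments a real train script would pass.

```python
# Hedged sketch: consuming the (train group, optional validation group) pair
# returned by generate_dataset_group_by_blueprint after this commit.
import argparse
from pathlib import Path

from library import config_util
from library.config_util import BlueprintGenerator, ConfigSanitizer

# "dataset_config.toml" is a hypothetical config file path.
user_config = config_util.load_user_config(Path("dataset_config.toml"))

# In a real training script this Namespace holds the parsed CLI arguments;
# it is left empty here only to keep the sketch short.
argparse_namespace = argparse.Namespace()

sanitizer = ConfigSanitizer(
    support_dreambooth=True,
    support_finetuning=False,
    support_controlnet=False,
    support_dropout=True,
)
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

# The generator now returns the training DatasetGroup plus an optional
# validation DatasetGroup (None when no dataset sets validation_split > 0.0).
train_group, val_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
```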