Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 14:34:23 +00:00)
Merge remote-tracking branch 'hina/feature/val-loss' into validation-loss-upstream
Modified implementation for process_batch and cleanup validation recording
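The merged branch adds per-dataset validation_split and validation_seed options (visible in the hunks below). As a rough, hypothetical sketch of the idea only, not code from this commit, a deterministic train/validation split of a fixed item list might look like:

    import random
    from typing import List, Sequence, Tuple

    def split_train_val(items: Sequence[str], validation_split: float, validation_seed: int) -> Tuple[List[str], List[str]]:
        # Shuffle a copy with a fixed seed so the same split is produced on every run.
        shuffled = list(items)
        random.Random(validation_seed).shuffle(shuffled)
        n_val = int(len(shuffled) * validation_split)
        # Returns (train_items, val_items); validation_split is the fraction held out.
        return shuffled[n_val:], shuffled[:n_val]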
@@ -10,13 +10,7 @@ import json
from pathlib import Path

# from toolz import curry
from typing import (
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from typing import Dict, List, Optional, Sequence, Tuple, Union

import toml
import voluptuous
@@ -78,6 +72,9 @@ class BaseSubsetParams:
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    token_warmup_step: float = 0
    custom_attributes: Optional[Dict[str, Any]] = None
    validation_seed: int = 0
    validation_split: float = 0.0


@dataclass
@@ -86,11 +83,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False
    alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    metadata_file: Optional[str] = None
    alpha_mask: bool = False


@dataclass
@@ -102,14 +101,13 @@ class ControlNetSubsetParams(BaseSubsetParams):

@dataclass
class BaseDatasetParams:
    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
    max_token_length: int = None
    resolution: Optional[Tuple[int, int]] = None
    network_multiplier: float = 1.0
    debug_dataset: bool = False
    validation_seed: Optional[int] = None
    validation_split: float = 0.0


@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
    batch_size: int = 1
@@ -191,11 +189,13 @@ class ConfigSanitizer:
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "secondary_separator": str,
        "caption_separator": str,
        "enable_wildcard": bool,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
        "caption_prefix": str,
        "caption_suffix": str,
        "custom_attributes": dict,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -212,11 +212,13 @@ class ConfigSanitizer:
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
        "alpha_mask": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
        "alpha_mask": bool,
    }
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
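For context, dicts like these are fed to voluptuous (already imported in this module), where Required marks mandatory keys. A hedged, stand-alone illustration of that validation pattern, not a line from the diff:

    from voluptuous import Required, Schema

    # Same shape as the DB_SUBSET_DISTINCT_SCHEMA entries shown above.
    db_subset_schema = Schema({
        Required("image_dir"): str,
        "is_reg": bool,
        "alpha_mask": bool,
    })

    # Returns the validated dict; raises voluptuous.MultipleInvalid if "image_dir"
    # is missing or a value has the wrong type.
    validated = db_subset_schema({"image_dir": "/data/imgs", "alpha_mask": True})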
@@ -480,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        datasets.append(dataset)

    val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
@@ -488,17 +490,17 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
        if dataset_blueprint.params.validation_split <= 0.0:
            continue
        if dataset_blueprint.is_controlnet:
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
        elif dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
        else:
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        val_datasets.append(dataset)

    # print info
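This validation loop mirrors the training-dataset loop in the previous hunk: a validation dataset is only materialized for blueprints whose validation_split is positive. A toy sketch of that control flow, with stand-in types rather than the repository's classes:

    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class _Blueprint:                  # hypothetical stand-in for a dataset blueprint
        name: str
        validation_split: float = 0.0

    def build_groups(blueprints: List[_Blueprint]) -> Tuple[List[str], List[str]]:
        datasets, val_datasets = [], []
        for bp in blueprints:
            datasets.append(f"train:{bp.name}")        # every blueprint yields a training dataset
            if bp.validation_split > 0.0:
                val_datasets.append(f"val:{bp.name}")  # only split blueprints yield a validation one
        return datasets, val_datasets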
@@ -543,6 +545,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
                random_crop: {subset.random_crop}
                token_warmup_min: {subset.token_warmup_min},
                token_warmup_step: {subset.token_warmup_step},
                alpha_mask: {subset.alpha_mask}
                custom_attributes: {subset.custom_attributes}
            """), " ")

            if is_dreambooth:
@@ -564,6 +568,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
    print("Validation dataset")
    print_info(val_datasets)

    if len(val_datasets) > 0:
        info = ""

        for i, dataset in enumerate(val_datasets):
            info += dedent(
                f"""\
                [Validation Dataset {i}]
                  batch_size: {dataset.batch_size}
                  resolution: {(dataset.width, dataset.height)}
                  enable_bucket: {dataset.enable_bucket}
                  network_multiplier: {dataset.network_multiplier}
                """
            )

            if dataset.enable_bucket:
                info += indent(
                    dedent(
                        f"""\
                        min_bucket_reso: {dataset.min_bucket_reso}
                        max_bucket_reso: {dataset.max_bucket_reso}
                        bucket_reso_steps: {dataset.bucket_reso_steps}
                        bucket_no_upscale: {dataset.bucket_no_upscale}
                        \n"""
                    ),
                    " ",
                )
            else:
                info += "\n"

            for j, subset in enumerate(dataset.subsets):
                info += indent(
                    dedent(
                        f"""\
                        [Subset {j} of Validation Dataset {i}]
                          image_dir: "{subset.image_dir}"
                          image_count: {subset.img_count}
                          num_repeats: {subset.num_repeats}
                        """
                    ),
                    " ",
                )

        logger.info(f"{info}")

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
@@ -574,7 +622,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
        dataset.set_seed(seed)

    for i, dataset in enumerate(val_datasets):
        print(f"[Validation Dataset {i}]")
        logger.info(f"[Validation Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)
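Training and validation datasets all receive the same base seed; per the comment above, the effective seed each epoch is seed + epoch_no, so shuffling stays reproducible across the whole group while still varying between epochs. A minimal illustration of that convention with an assumed helper, not code from this commit:

    import random

    base_seed = random.randint(0, 2**31)  # one base seed shared by every dataset, as in the diff

    def shuffle_for_epoch(items, epoch_no: int):
        # Effective seed follows the "seed + epoch_no" convention noted above.
        rng = random.Random(base_seed + epoch_no)
        shuffled = list(items)
        rng.shuffle(shuffled)
        return shuffled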