Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 14:34:23 +00:00
Add validation split of datasets
@@ -85,6 +85,8 @@ class BaseDatasetParams:
    max_token_length: int = None
    resolution: Optional[Tuple[int, int]] = None
    debug_dataset: bool = False
    validation_seed: Optional[int] = None
    validation_split: float = 0.0


@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):

@@ -200,6 +202,8 @@ class ConfigSanitizer:
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "validation_seed": int,
        "validation_split": float,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
    }

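The two new options are dataset-level settings validated by ConfigSanitizer, so a user config can set them per dataset. A minimal, hedged sketch of such a config as a Python dict (the "datasets"/"subsets" layout follows the existing config_util conventions; exact key placement in a real TOML config may differ):

user_config = {
    "datasets": [
        {
            "resolution": 512,
            "enable_bucket": True,
            "validation_seed": 42,     # fixes the shuffle used for the train/val split
            "validation_split": 0.1,   # reserve 10% of each image directory for validation
            "subsets": [
                {"image_dir": "/path/to/images", "num_repeats": 1},
            ],
        }
    ]
}
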
@@ -427,64 +431,89 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), "  ")
val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
info += "\n"
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), "  ")

if is_dreambooth:
# print info
def print_info(_datasets):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), "  ")
elif not is_controlnet:
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), "  ")
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), "  ")
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), "  ")

if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), "  ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), "  ")

print(info)
print(info)

print_info(datasets)

if len(val_datasets) > 0:
print("Validation dataset")
print_info(val_datasets)

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets

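Stripped of the print-info reshuffling, the substance of this hunk is: every blueprinted dataset is now built with is_train=True, and any dataset whose validation_split is positive is built a second time with is_train=False and collected into val_datasets. A toy sketch of that shape (FakeDataset and build_groups are illustrative stand-ins, not the real classes):

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeDataset:
    # stand-in for DreamBoothDataset / FineTuningDataset / ControlNetDataset
    is_train: bool
    validation_split: float


def build_groups(blueprint_params: List[dict]) -> Tuple[List[FakeDataset], Optional[List[FakeDataset]]]:
    datasets, val_datasets = [], []
    for params in blueprint_params:
        datasets.append(FakeDataset(is_train=True, **params))
        if params["validation_split"] > 0.0:
            # a second instance of the same dataset, flagged as validation;
            # it will keep only the validation slice of each image directory
            val_datasets.append(FakeDataset(is_train=False, **params))
    return datasets, (val_datasets if val_datasets else None)


train_sets, val_sets = build_groups([{"validation_split": 0.1}, {"validation_split": 0.0}])
print(len(train_sets), len(val_sets) if val_sets else None)  # -> 2 1
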
@@ -494,7 +523,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)
    for i, dataset in enumerate(val_datasets):
        print(f"[Validation Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return (
        DatasetGroup(datasets),
        DatasetGroup(val_datasets) if val_datasets else None
    )


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):

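Because the return value is now a pair, every caller has to unpack two values, the second being None when no dataset asked for a validation split. The call-site pattern, mirroring the train_network.py hunk further down (config_util and blueprint come from the surrounding trainer code):

train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
if val_dataset_group is None:
    print("no dataset requested a validation split; validation is skipped")
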
@@ -123,6 +123,22 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"


def split_train_val(paths, is_train, validation_split, validation_seed):
    if validation_seed is not None:
        print(f"Using validation seed: {validation_seed}")
        prevstate = random.getstate()
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(prevstate)
    else:
        random.shuffle(paths)

    if is_train:
        return paths[0:math.ceil(len(paths) * (1 - validation_split))]
    else:
        return paths[len(paths) - round(len(paths) * validation_split):]


class ImageInfo:
    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        self.image_key: str = image_key

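A quick way to see what the ceil/round arithmetic does is to run the helper on a small list (the import path is assumed to be library.train_util, where this hunk appears to live):

from library.train_val import split_train_val  # module path assumed; adjust to library.train_util if needed
from library.train_util import split_train_val  # assumed location based on the hunk above

paths = [f"{i:02d}.png" for i in range(10)]
train_part = split_train_val(list(paths), True, 0.2, 42)   # first ceil(10 * 0.8) = 8 items of the shuffled list
val_part = split_train_val(list(paths), False, 0.2, 42)    # last round(10 * 0.2) = 2 items of the same shuffled list
print(len(train_part), len(val_part))  # -> 8 2, disjoint in this case

# Edge case worth knowing: with 3 paths and validation_split=0.5 the train slice is
# shuffled[:2] and the validation slice is shuffled[1:], so one item lands in both.
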
@@ -1314,6 +1330,7 @@ class DreamBoothDataset(BaseDataset):
    def __init__(
        self,
        subsets: Sequence[DreamBoothSubset],
        is_train: bool,
        batch_size: int,
        tokenizer,
        max_token_length,
@@ -1324,12 +1341,18 @@ class DreamBoothDataset(BaseDataset):
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        prior_loss_weight: float,
        validation_split: float,
        validation_seed: Optional[int],
        debug_dataset,
    ) -> None:
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

        self.is_train = is_train
        self.validation_split = validation_split
        self.validation_seed = validation_seed

        self.batch_size = batch_size
        self.size = min(self.width, self.height)  # 短いほう (the shorter side)
        self.prior_loss_weight = prior_loss_weight

@@ -1382,6 +1405,9 @@ class DreamBoothDataset(BaseDataset):
            return [], []

        img_paths = glob_images(subset.image_dir, "*")

        if self.validation_split > 0.0:
            img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
        print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

        # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う (read a caption/prompt per image file and use it if present)

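The training and validation DreamBoothDataset instances glob the same image_dir and call split_train_val with opposite is_train flags, so disjointness of the two halves rests on the shuffle being reproducible. A hedged sketch of why validation_seed matters (import path assumed as above):

from library.train_util import split_train_val  # module path assumed

paths = [f"/data/imgs/{i:03d}.png" for i in range(20)]

# With a fixed validation_seed both calls shuffle their copies identically, so the
# train slice (first 18) and the validation slice (last 2) never overlap.
train_part = split_train_val(list(paths), True, 0.1, 1234)
val_part = split_train_val(list(paths), False, 0.1, 1234)
assert not set(train_part) & set(val_part)

# Without a seed each call shuffles with the current global RNG state, so the two
# selections come from different orderings and are not guaranteed to be disjoint.
train_maybe = split_train_val(list(paths), True, 0.1, None)
val_maybe = split_train_val(list(paths), False, 0.1, None)
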
@@ -189,10 +189,11 @@ class NetworkTrainer:
            }

            blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
            train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
            train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
        else:
            # use arbitrary dataset class
            train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
            val_dataset_group = None  # placeholder until validation dataset supported for arbitrary

        current_epoch = Value("i", 0)
        current_step = Value("i", 0)

@@ -212,6 +213,10 @@ class NetworkTrainer:
            assert (
                train_dataset_group.is_latent_cacheable()
            ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
            if val_dataset_group is not None:
                assert (
                    val_dataset_group.is_latent_cacheable()
                ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

        self.assert_extra_args(args, train_dataset_group)

@@ -264,6 +269,9 @@ class NetworkTrainer:
            vae.eval()
            with torch.no_grad():
                train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
                if val_dataset_group is not None:
                    print("Cache validation latents...")
                    val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
            vae.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

@@ -345,61 +353,8 @@ class NetworkTrainer:
        # DataLoaderのプロセス数:0はメインプロセスになる (number of DataLoader worker processes; 0 means the main process)
        n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで (cpu_count - 1, capped at the specified value)

        def get_indices_without_reg(dataset: torch.utils.data.Dataset):
            return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]

        from typing import Sequence, Union
        from torch._utils import _accumulate
        import warnings
        from torch.utils.data.dataset import Subset

        def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
            indices = get_indices_without_reg(dataset)
            random.shuffle(indices)

            subset_lengths = []

            for i, frac in enumerate(lengths):
                if frac < 0 or frac > 1:
                    raise ValueError(f"Fraction at index {i} is not between 0 and 1")
                n_items_in_split = int(math.floor(len(indices) * frac))
                subset_lengths.append(n_items_in_split)

            remainder = len(indices) - sum(subset_lengths)

            for i in range(remainder):
                idx_to_add_at = i % len(subset_lengths)
                subset_lengths[idx_to_add_at] += 1

            lengths = subset_lengths
            for i, length in enumerate(lengths):
                if length == 0:
                    warnings.warn(f"Length of split at index {i} is 0. "
                                  f"This might result in an empty dataset.")

            if sum(lengths) != len(indices):
                raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

            return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]

        if args.validation_ratio > 0.0:
            train_ratio = 1 - args.validation_ratio
            validation_ratio = args.validation_ratio
            train, val = random_split(
                train_dataset_group,
                [train_ratio, validation_ratio]
            )
            print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
            print(f"train images: {len(train)}, validation images: {len(val)}")
        else:
            train = train_dataset_group
            val = []

        train_dataloader = torch.utils.data.DataLoader(
            train,
            train_dataset_group,
            batch_size=1,
            shuffle=True,
            collate_fn=collator,

@@ -408,7 +363,7 @@ class NetworkTrainer:
        )

        val_dataloader = torch.utils.data.DataLoader(
            val,
            val_dataset_group if val_dataset_group is not None else [],
            shuffle=False,
            batch_size=1,
            collate_fn=collator,

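The diff wires up val_dataloader but does not show how it is consumed; presumably the training loop periodically runs a no-grad pass over it with the same per-batch loss used for training. A hypothetical sketch of that pattern (not part of this commit; compute_batch_loss is a placeholder):

import torch


def validation_loss(compute_batch_loss, val_dataloader):
    # Hypothetical consumer of val_dataloader: average the loss without updating weights.
    total, count = 0.0, 0
    with torch.no_grad():
        for batch in val_dataloader:
            total += compute_batch_loss(batch).item()
            count += 1
    return total / max(count, 1)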