mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix regularization images with validation
Adding metadata recording for validation arguments Add comments about the validation split for clarity of intention
This commit is contained in:
@@ -146,7 +146,12 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
|
||||
@@ -1830,6 +1835,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# if is_training_dataset is True -> training dataset
|
||||
# if is_training_dataset is False -> validation dataset
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
@@ -1965,8 +1973,29 @@ class DreamBoothDataset(BaseDataset):
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
#
|
||||
# We split the dataset for the subset based on if we are doing a validation split
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# if self.is_training_dataset is True -> training dataset
|
||||
# if self.is_training_dataset is False -> validation dataset
|
||||
if self.validation_split > 0.0:
|
||||
img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
|
||||
# For regularization images we do not want to split this dataset.
|
||||
if subset.is_reg is True:
|
||||
# Skip any validation dataset for regularization images
|
||||
if self.is_training_dataset is False:
|
||||
img_paths = []
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# required for training images dataset of regularization images
|
||||
else:
|
||||
img_paths = split_train_val(
|
||||
img_paths,
|
||||
self.is_training_dataset,
|
||||
self.validation_split,
|
||||
self.validation_seed
|
||||
)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user