Merge branch 'sd3' into val-loss-improvement

This commit is contained in:
Kohya S
2025-02-18 21:34:30 +09:00
3 changed files with 36 additions and 10 deletions

View File

@@ -148,10 +148,11 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val(
    paths: List[str],
    sizes: List[Optional[Tuple[int, int]]],
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
    """
    Split the dataset into train and validation

    ``paths`` and ``sizes`` are paired element by element and shuffled
    together, so every path keeps its own size after the split. The input
    lists are not modified.

    The shuffled data is cut at a single index: the training part is
    everything before it and the validation part is everything from it on.
    Using one index for both parts guarantees that no item is returned for
    both training and validation.

    Example with 100 images and validation_split = 0.2:
        [0:80] = 80 training images
        [80:] = 20 validation images

    Args:
        paths: image paths.
        sizes: per-image size, or None when unknown; paired with ``paths``.
        is_training_dataset: True returns the training part, False returns
            the validation part.
        validation_split: fraction of the dataset reserved for validation.
        validation_seed: seed for the shuffle. When given, the global
            ``random`` state is restored afterwards. When None, the current
            global state is used, so separate calls are only consistent with
            each other if a seed is supplied.

    Returns:
        ``(paths, sizes)`` for the requested part. Both lists are empty when
        the input is empty.
    """
    dataset = list(zip(paths, sizes))
    if validation_seed is not None:
        logging.info(f"Using validation seed: {validation_seed}")
        # Seed only for this shuffle, then restore the caller's RNG state.
        prevstate = random.getstate()
        random.seed(validation_seed)
        random.shuffle(dataset)
        random.setstate(prevstate)
    else:
        random.shuffle(dataset)

    # Comprehensions instead of zip(*dataset): unpacking an empty zip raises
    # ValueError, whereas these simply yield two empty lists.
    shuffled_paths = [path for path, _ in dataset]
    shuffled_sizes = [size for _, size in dataset]

    # One boundary shared by both parts. Computing it separately with ceil()
    # for training and round() for validation can disagree by one item
    # (e.g. 3 images with validation_split=0.5) and leak a training image
    # into the validation set.
    split = math.ceil(len(shuffled_paths) * (1 - validation_split))
    if is_training_dataset:
        return shuffled_paths[:split], shuffled_sizes[:split]
    return shuffled_paths[split:], shuffled_sizes[split:]
class ImageInfo:
@@ -1931,12 +1938,12 @@ class DreamBoothDataset(BaseDataset):
with open(info_cache_file, "r", encoding="utf-8") as f:
metas = json.load(f)
img_paths = list(metas.keys())
sizes = [meta["resolution"] for meta in metas.values()]
sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()]
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
else:
img_paths = glob_images(subset.image_dir, "*")
sizes = [None] * len(img_paths)
sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths)
# new caching: get image size from cache files
strategy = LatentsCachingStrategy.get_strategy()
@@ -1969,7 +1976,7 @@ class DreamBoothDataset(BaseDataset):
w, h = None, None
if w is not None and h is not None:
sizes[i] = [w, h]
sizes[i] = (w, h)
size_set_count += 1
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
@@ -1987,11 +1994,13 @@ class DreamBoothDataset(BaseDataset):
# Skip any validation dataset for regularization images
if self.is_training_dataset is False:
img_paths = []
sizes = []
# Otherwise the img_paths remain as original img_paths and no split
# required for training images dataset of regularization images
else:
img_paths = split_train_val(
img_paths, sizes = split_train_val(
img_paths,
sizes,
self.is_training_dataset,
self.validation_split,
self.validation_seed