From 4671e237781dcfe9a16e90f5343afd57586a1df6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:42:44 -0500 Subject: [PATCH 1/4] Fix validation epoch loss to check epoch average --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index b5f92e06..674f1cb6 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 3c7496ae3f2736a8283a881f49698d3e8f3a4291 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:18:14 -0500 Subject: [PATCH 2/4] Fix sizes for validation split --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a99..6c782ea1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -148,10 +148,11 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" def split_train_val( paths: List[str], + sizes: List[Optional[Tuple[int, int]]], is_training_dataset: bool, validation_split: float, validation_seed: int | None -) -> List[str]: +) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -172,10 +173,12 @@ def split_train_val( # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part - return paths[0:math.ceil(len(paths) * (1 - validation_split))] + split = math.ceil(len(paths) * (1 - validation_split)) + return paths[0:split], sizes[0:split] else: # Validation dataset we split to the second part - return paths[len(paths) - round(len(paths) * validation_split):] + split = len(paths) - round(len(paths) * validation_split) + return paths[split:], sizes[split:] class ImageInfo: @@ -1931,12 +1934,12 @@ class DreamBoothDataset(BaseDataset): with open(info_cache_file, "r", encoding="utf-8") as f: metas = json.load(f) img_paths = list(metas.keys()) - sizes = [meta["resolution"] for meta in metas.values()] + sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - sizes = [None] * len(img_paths) + sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths) # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() @@ -1969,7 +1972,7 @@ class DreamBoothDataset(BaseDataset): w, h = None, None if w is not None and h is not None: - sizes[i] = [w, h] + sizes[i] = (w, h) size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") @@ -1990,8 +1993,9 @@ class DreamBoothDataset(BaseDataset): # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: - img_paths = split_train_val( + img_paths, sizes = split_train_val( img_paths, + sizes, self.is_training_dataset, self.validation_split, self.validation_seed From f3a010978c0e4b88c4839b3a81400b8973f52158 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:28:34 -0500 Subject: [PATCH 3/4] Clear sizes for validation reg images to be consistent --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 6c782ea1..39b4af85 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1990,6 +1990,7 @@ class DreamBoothDataset(BaseDataset): # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] + sizes = [] # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: From 9436b410617f22716eac64f7c604c8f53fa8c1a8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 17 Feb 2025 14:28:41 -0500 Subject: [PATCH 4/4] Fix validation split and add test --- library/train_util.py | 8 ++++++-- tests/test_validation.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/test_validation.py diff --git a/library/train_util.py b/library/train_util.py index 39b4af85..b2329066 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -161,15 +161,19 @@ def split_train_val( [0:80] = 80 training images [80:] = 20 validation images """ + dataset = list(zip(paths, sizes)) if validation_seed is not None: logging.info(f"Using validation seed: {validation_seed}") prevstate = random.getstate() random.seed(validation_seed) - random.shuffle(paths) + random.shuffle(dataset) random.setstate(prevstate) else: - random.shuffle(paths) + random.shuffle(dataset) + paths, sizes = zip(*dataset) + paths = list(paths) + sizes = list(sizes) # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..f80686d8 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,17 @@ +from library.train_util import split_train_val + + +def test_split_train_val(): + paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"] + sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None] + result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234) + assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths + assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes + + result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234) + assert result_paths == ["path7"], result_paths + assert result_sizes == [None], result_sizes + + +if __name__ == "__main__": + test_split_train_val()