diff --git a/library/train_util.py b/library/train_util.py index 7c6dbbdd..36ded89d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2736,6 +2736,29 @@ class DatasetGroup(torch.utils.data.ConcatDataset): """ from pathlib import Path + # Validate that latent caching is enabled + # CDC requires latents to be cached (either to disk or in memory) because: + # 1. CDC files are named based on latent cache filenames + # 2. CDC files are saved next to latent cache files + # 3. Training needs latent paths to load corresponding CDC files + has_cached_latents = False + for dataset in self.datasets: + for info in dataset.image_data.values(): + if info.latents is not None or info.latents_npz is not None: + has_cached_latents = True + break + if has_cached_latents: + break + + if not has_cached_latents: + raise ValueError( + "CDC-FM requires latent caching to be enabled. " + "Please enable latent caching by setting one of:\n" + " - cache_latents = true (cache in memory)\n" + " - cache_latents_to_disk = true (cache to disk)\n" + "in your training config or command line arguments." + ) + # Collect dataset/subset directories for config hash dataset_dirs = [] for dataset in self.datasets: diff --git a/tests/library/test_cdc_cache_detection.py b/tests/library/test_cdc_cache_detection.py index c76af198..faba2058 100644 --- a/tests/library/test_cdc_cache_detection.py +++ b/tests/library/test_cdc_cache_detection.py @@ -243,6 +243,43 @@ def test_cdc_cache_detection_partial_cache(): assert result is False, "Should detect that some CDC cache files are missing" +def test_cdc_requires_latent_caching(): + """ + Test that CDC-FM gives a clear error when latent caching is not enabled. + """ + + with tempfile.TemporaryDirectory() as tmpdir: + # Setup mock dataset with NO latent caching (both latents and latents_npz are None) + image_info = ImageInfo( + image_key="test_image", + num_repeats=1, + caption="test", + is_reg=False, + absolute_path=str(Path(tmpdir) / "image.png") + ) + image_info.latents_npz = None # No disk cache + image_info.latents = None # No memory cache + image_info.bucket_reso = (512, 512) + + mock_dataset = MockDataset({"test_image": image_info}) + dataset_group = DatasetGroup([mock_dataset]) + + # Test: Attempt to cache CDC without latent caching enabled + with pytest.raises(ValueError) as exc_info: + dataset_group.cache_cdc_gamma_b( + k_neighbors=256, + k_bandwidth=8, + d_cdc=8, + gamma=1.0 + ) + + # Verify: Error message should mention latent caching requirement + error_message = str(exc_info.value) + assert "CDC-FM requires latent caching" in error_message + assert "cache_latents" in error_message + assert "cache_latents_to_disk" in error_message + + if __name__ == "__main__": # Run tests with verbose output pytest.main([__file__, "-v"])