Add error if with CDC if cache_latents or cache_latents_to_disk is not set

This commit is contained in:
rockerBOO
2025-11-03 21:47:15 -05:00
parent 377299851a
commit 7a08c52aa4
2 changed files with 60 additions and 0 deletions

View File

@@ -243,6 +243,43 @@ def test_cdc_cache_detection_partial_cache():
assert result is False, "Should detect that some CDC cache files are missing"
def test_cdc_requires_latent_caching():
"""
Test that CDC-FM gives a clear error when latent caching is not enabled.
"""
with tempfile.TemporaryDirectory() as tmpdir:
# Setup mock dataset with NO latent caching (both latents and latents_npz are None)
image_info = ImageInfo(
image_key="test_image",
num_repeats=1,
caption="test",
is_reg=False,
absolute_path=str(Path(tmpdir) / "image.png")
)
image_info.latents_npz = None # No disk cache
image_info.latents = None # No memory cache
image_info.bucket_reso = (512, 512)
mock_dataset = MockDataset({"test_image": image_info})
dataset_group = DatasetGroup([mock_dataset])
# Test: Attempt to cache CDC without latent caching enabled
with pytest.raises(ValueError) as exc_info:
dataset_group.cache_cdc_gamma_b(
k_neighbors=256,
k_bandwidth=8,
d_cdc=8,
gamma=1.0
)
# Verify: Error message should mention latent caching requirement
error_message = str(exc_info.value)
assert "CDC-FM requires latent caching" in error_message
assert "cache_latents" in error_message
assert "cache_latents_to_disk" in error_message
if __name__ == "__main__":
# Run tests with verbose output
pytest.main([__file__, "-v"])