mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add error if with CDC if cache_latents or cache_latents_to_disk is not set
This commit is contained in:
@@ -2736,6 +2736,29 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
# Validate that latent caching is enabled
|
||||
# CDC requires latents to be cached (either to disk or in memory) because:
|
||||
# 1. CDC files are named based on latent cache filenames
|
||||
# 2. CDC files are saved next to latent cache files
|
||||
# 3. Training needs latent paths to load corresponding CDC files
|
||||
has_cached_latents = False
|
||||
for dataset in self.datasets:
|
||||
for info in dataset.image_data.values():
|
||||
if info.latents is not None or info.latents_npz is not None:
|
||||
has_cached_latents = True
|
||||
break
|
||||
if has_cached_latents:
|
||||
break
|
||||
|
||||
if not has_cached_latents:
|
||||
raise ValueError(
|
||||
"CDC-FM requires latent caching to be enabled. "
|
||||
"Please enable latent caching by setting one of:\n"
|
||||
" - cache_latents = true (cache in memory)\n"
|
||||
" - cache_latents_to_disk = true (cache to disk)\n"
|
||||
"in your training config or command line arguments."
|
||||
)
|
||||
|
||||
# Collect dataset/subset directories for config hash
|
||||
dataset_dirs = []
|
||||
for dataset in self.datasets:
|
||||
|
||||
@@ -243,6 +243,43 @@ def test_cdc_cache_detection_partial_cache():
|
||||
assert result is False, "Should detect that some CDC cache files are missing"
|
||||
|
||||
|
||||
def test_cdc_requires_latent_caching():
|
||||
"""
|
||||
Test that CDC-FM gives a clear error when latent caching is not enabled.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Setup mock dataset with NO latent caching (both latents and latents_npz are None)
|
||||
image_info = ImageInfo(
|
||||
image_key="test_image",
|
||||
num_repeats=1,
|
||||
caption="test",
|
||||
is_reg=False,
|
||||
absolute_path=str(Path(tmpdir) / "image.png")
|
||||
)
|
||||
image_info.latents_npz = None # No disk cache
|
||||
image_info.latents = None # No memory cache
|
||||
image_info.bucket_reso = (512, 512)
|
||||
|
||||
mock_dataset = MockDataset({"test_image": image_info})
|
||||
dataset_group = DatasetGroup([mock_dataset])
|
||||
|
||||
# Test: Attempt to cache CDC without latent caching enabled
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
dataset_group.cache_cdc_gamma_b(
|
||||
k_neighbors=256,
|
||||
k_bandwidth=8,
|
||||
d_cdc=8,
|
||||
gamma=1.0
|
||||
)
|
||||
|
||||
# Verify: Error message should mention latent caching requirement
|
||||
error_message = str(exc_info.value)
|
||||
assert "CDC-FM requires latent caching" in error_message
|
||||
assert "cache_latents" in error_message
|
||||
assert "cache_latents_to_disk" in error_message
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests with verbose output
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user