Files
Kohya-ss-sd-scripts/tests/library/test_cdc_adaptive_k.py
rockerBOO 8089cb6925 Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
2025-10-11 17:17:09 -04:00

229 lines
7.7 KiB
Python

"""
Test adaptive k_neighbors functionality in CDC-FM.
Verifies that adaptive k properly adjusts based on bucket sizes.
"""
import pytest
import torch
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestAdaptiveK:
"""Test adaptive k_neighbors behavior"""
@pytest.fixture
def temp_cache_path(self, tmp_path):
"""Create temporary cache path"""
return tmp_path / "adaptive_k_test.safetensors"
def test_fixed_k_skips_small_buckets(self, temp_cache_path):
"""
Test that fixed k mode skips buckets with < k_neighbors samples.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=False # Fixed mode
)
# Add 10 samples (< k=32, should be skipped)
shape = (4, 16, 16)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify zeros (Gaussian fallback)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should be all zeros (fallback)
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
"""
Test that adaptive k mode uses k=bucket_size-1 for small buckets.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=8
)
# Add 20 samples (< k=32, should use k=19)
shape = (4, 16, 16)
for i in range(20):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify non-zero (CDC computed)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should NOT be all zeros (CDC was computed)
assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
"""
Test that adaptive k mode skips buckets below min_bucket_size.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=16
)
# Add 10 samples (< min_bucket_size=16, should be skipped)
shape = (4, 16, 16)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify zeros (skipped due to min_bucket_size)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should be all zeros (skipped)
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
"""
Test adaptive k with multiple buckets of different sizes.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=8
)
# Bucket 1: 10 samples (adaptive k=9)
for i in range(10):
latent = torch.randn(4, 16, 16, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=(4, 16, 16),
metadata={'image_key': f'small_{i}'}
)
# Bucket 2: 40 samples (full k=32)
for i in range(40):
latent = torch.randn(4, 32, 32, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=100+i,
shape=(4, 32, 32),
metadata={'image_key': f'large_{i}'}
)
# Bucket 3: 5 samples (< min=8, skipped)
for i in range(5):
latent = torch.randn(4, 8, 8, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=200+i,
shape=(4, 8, 8),
metadata={'image_key': f'tiny_{i}'}
)
preprocessor.compute_all(temp_cache_path)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
# Bucket 1: Should have CDC (non-zero)
eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6)
# Bucket 2: Should have CDC (non-zero)
eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6)
# Bucket 3: Should be skipped (zeros)
eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6)
assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6)
def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
"""
Test that adaptive k uses full k_neighbors when bucket is large enough.
"""
preprocessor = CDCPreprocessor(
k_neighbors=16,
k_bandwidth=4,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=8
)
# Add 50 samples (> k=16, should use full k=16)
shape = (4, 16, 16)
for i in range(50):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify CDC was computed
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should have non-zero eigenvalues
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
# Eigenvalues should be positive
assert (eigvals >= 0).all()
if __name__ == "__main__":
pytest.main([__file__, "-v"])