Files
Kohya-ss-sd-scripts/tests/library/test_cdc_adaptive_k.py
rockerBOO 7ca799ca26 Add adaptive k_neighbors support for CDC-FM
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size
- Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16)
- Fixed mode (default): Skip buckets with < k_neighbors samples
- Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size
- Update CDCPreprocessor to support adaptive k per bucket
- Add metadata tracking for adaptive_k and min_bucket_size
- Add comprehensive pytest tests for adaptive k behavior

This allows CDC-FM to work effectively with multi-resolution bucketing, where
bucket sizes may vary widely. Users can choose between the strict paper methodology
(fixed k) and a pragmatic approach (adaptive k).
2025-10-09 23:16:44 -04:00

231 lines
7.8 KiB
Python

"""
Test adaptive k_neighbors functionality in CDC-FM.
Verifies that adaptive k properly adjusts based on bucket sizes.
"""
import pytest
import torch
import numpy as np
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestAdaptiveK:
    """Test adaptive k_neighbors behavior.

    Every test follows the same recipe: build a ``CDCPreprocessor``, feed it
    random latents grouped by shape, run ``compute_all`` into a temporary
    cache file, reload that file through ``GammaBDataset`` and check whether
    the stored eigenvectors/eigenvalues are all zeros (bucket skipped) or
    not (CDC computed).
    """

    # Tolerance used when deciding whether a stored tensor is "all zeros".
    ZERO_ATOL = 1e-6

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's global RNG so the random latents are reproducible."""
        torch.manual_seed(0)

    @pytest.fixture
    def temp_cache_path(self, tmp_path):
        """Create temporary cache path"""
        return tmp_path / "adaptive_k_test.safetensors"

    @staticmethod
    def _make_preprocessor(k_neighbors=32, k_bandwidth=8, **mode_kwargs):
        """Build a CPU ``CDCPreprocessor`` with the settings shared by all tests.

        ``mode_kwargs`` is forwarded untouched so each test controls exactly
        which of ``adaptive_k`` / ``min_bucket_size`` it passes (and which it
        leaves at the constructor's defaults).
        """
        return CDCPreprocessor(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            **mode_kwargs,
        )

    @staticmethod
    def _add_random_latents(preprocessor, count, shape, key_prefix='test', start_idx=0):
        """Register ``count`` random latents of ``shape`` with ``preprocessor``.

        Sample ``i`` gets ``global_idx = start_idx + i`` and the image key
        ``f'{key_prefix}_{i}'``.
        """
        for i in range(count):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=start_idx + i,
                shape=shape,
                metadata={'image_key': f'{key_prefix}_{i}'},
            )

    @staticmethod
    def _compute_and_load(preprocessor, cache_path):
        """Run ``compute_all`` into ``cache_path`` and reopen it as a ``GammaBDataset``."""
        preprocessor.compute_all(cache_path)
        return GammaBDataset(gamma_b_path=cache_path, device='cpu')

    @classmethod
    def _is_all_zero(cls, tensor):
        """Return True when every element of ``tensor`` is within ``ZERO_ATOL`` of zero."""
        return torch.allclose(tensor, torch.zeros_like(tensor), atol=cls.ZERO_ATOL)

    def test_fixed_k_skips_small_buckets(self, temp_cache_path):
        """
        Test that fixed k mode skips buckets with < k_neighbors samples.
        """
        preprocessor = self._make_preprocessor(adaptive_k=False)  # Fixed mode

        # Add 10 samples (< k=32, should be skipped)
        self._add_random_latents(preprocessor, count=10, shape=(4, 16, 16))

        # Load and verify zeros (Gaussian fallback)
        dataset = self._compute_and_load(preprocessor, temp_cache_path)
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should be all zeros (fallback)
        assert self._is_all_zero(eigvecs)
        assert self._is_all_zero(eigvals)

    def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
        """
        Test that adaptive k mode uses k=bucket_size-1 for small buckets.
        """
        preprocessor = self._make_preprocessor(adaptive_k=True, min_bucket_size=8)

        # Add 20 samples (< k=32, should use k=19)
        self._add_random_latents(preprocessor, count=20, shape=(4, 16, 16))

        # Load and verify non-zero (CDC computed)
        dataset = self._compute_and_load(preprocessor, temp_cache_path)
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should NOT be all zeros (CDC was computed)
        assert not self._is_all_zero(eigvecs)
        assert not self._is_all_zero(eigvals)

    def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
        """
        Test that adaptive k mode skips buckets below min_bucket_size.
        """
        preprocessor = self._make_preprocessor(adaptive_k=True, min_bucket_size=16)

        # Add 10 samples (< min_bucket_size=16, should be skipped)
        self._add_random_latents(preprocessor, count=10, shape=(4, 16, 16))

        # Load and verify zeros (skipped due to min_bucket_size)
        dataset = self._compute_and_load(preprocessor, temp_cache_path)
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should be all zeros (skipped)
        assert self._is_all_zero(eigvecs)
        assert self._is_all_zero(eigvals)

    def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
        """
        Test adaptive k with multiple buckets of different sizes.
        """
        preprocessor = self._make_preprocessor(adaptive_k=True, min_bucket_size=8)

        # Bucket 1: 10 samples (adaptive k=9)
        self._add_random_latents(
            preprocessor, count=10, shape=(4, 16, 16), key_prefix='small'
        )
        # Bucket 2: 40 samples (full k=32)
        self._add_random_latents(
            preprocessor, count=40, shape=(4, 32, 32), key_prefix='large', start_idx=100
        )
        # Bucket 3: 5 samples (< min=8, skipped)
        self._add_random_latents(
            preprocessor, count=5, shape=(4, 8, 8), key_prefix='tiny', start_idx=200
        )

        dataset = self._compute_and_load(preprocessor, temp_cache_path)

        # Bucket 1: Should have CDC (non-zero)
        eigvecs_small, _ = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
        assert not self._is_all_zero(eigvecs_small)

        # Bucket 2: Should have CDC (non-zero)
        eigvecs_large, _ = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
        assert not self._is_all_zero(eigvecs_large)

        # Bucket 3: Should be skipped (zeros)
        eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
        assert self._is_all_zero(eigvecs_tiny)
        assert self._is_all_zero(eigvals_tiny)

    def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
        """
        Test that adaptive k uses full k_neighbors when bucket is large enough.
        """
        preprocessor = self._make_preprocessor(
            k_neighbors=16,
            k_bandwidth=4,
            adaptive_k=True,
            min_bucket_size=8,
        )

        # Add 50 samples (> k=16, should use full k=16)
        self._add_random_latents(preprocessor, count=50, shape=(4, 16, 16))

        # Load and verify CDC was computed
        dataset = self._compute_and_load(preprocessor, temp_cache_path)
        _, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should have non-zero eigenvalues
        assert not self._is_all_zero(eigvals)
        # Eigenvalues should be non-negative
        assert (eigvals >= 0).all()
if __name__ == "__main__":
    # pytest.main returns the session's exit code; re-raise it as the process
    # exit status so a direct `python <this file>` run fails when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))