Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 08:52:45 +00:00
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size
- Add --cdc_min_bucket_size to set the minimum bucket threshold (default: 16)
- Fixed mode (default): skip buckets with < k_neighbors samples
- Adaptive mode: use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size
- Update CDCPreprocessor to support adaptive k per bucket
- Add metadata tracking for adaptive_k and min_bucket_size
- Add comprehensive pytest tests for adaptive k behavior

This allows CDC-FM to work effectively with multi-resolution bucketing, where bucket sizes may vary widely. Users can choose between the strict paper methodology (fixed k) and a pragmatic approach (adaptive k).
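As a rough illustration of the per-bucket selection rule described above, here is a minimal sketch. The helper name select_k is hypothetical and only restates the rule from the commit message; the actual logic lives inside CDCPreprocessor in library/cdc_fm.py.

def select_k(bucket_size: int, k_neighbors: int,
             adaptive_k: bool, min_bucket_size: int):
    """Return the k to use for a bucket, or None if the bucket is skipped."""
    if not adaptive_k:
        # Fixed mode (strict paper methodology): skip undersized buckets.
        return k_neighbors if bucket_size >= k_neighbors else None
    if bucket_size < min_bucket_size:
        # Adaptive mode still skips buckets below the minimum threshold.
        return None
    # Adaptive mode: never request more neighbors than the bucket can supply.
    return min(k_neighbors, bucket_size - 1)

For example, a 20-sample bucket with k_neighbors=32 and min_bucket_size=8 yields k=19 in adaptive mode, while fixed mode would skip it and fall back to plain Gaussian noise; the tests below exercise exactly these cases.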
"""
|
|
Test adaptive k_neighbors functionality in CDC-FM.
|
|
|
|
Verifies that adaptive k properly adjusts based on bucket sizes.
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
|
|
|
|
|
class TestAdaptiveK:
    """Test adaptive k_neighbors behavior"""

    @pytest.fixture
    def temp_cache_path(self, tmp_path):
        """Create temporary cache path"""
        return tmp_path / "adaptive_k_test.safetensors"

    def test_fixed_k_skips_small_buckets(self, temp_cache_path):
        """
        Test that fixed k mode skips buckets with < k_neighbors samples.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=False  # Fixed mode
        )

        # Add 10 samples (< k=32, should be skipped)
        shape = (4, 16, 16)
        for i in range(10):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                shape=shape,
                metadata={'image_key': f'test_{i}'}
            )

        preprocessor.compute_all(temp_cache_path)

        # Load and verify zeros (Gaussian fallback)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should be all zeros (fallback)
        assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
        assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)

    def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
        """
        Test that adaptive k mode uses k=bucket_size-1 for small buckets.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8
        )

        # Add 20 samples (< k=32, should use k=19)
        shape = (4, 16, 16)
        for i in range(20):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                shape=shape,
                metadata={'image_key': f'test_{i}'}
            )

        preprocessor.compute_all(temp_cache_path)

        # Load and verify non-zero (CDC computed)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should NOT be all zeros (CDC was computed)
        assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
        assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)

    def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
        """
        Test that adaptive k mode skips buckets below min_bucket_size.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=16
        )

        # Add 10 samples (< min_bucket_size=16, should be skipped)
        shape = (4, 16, 16)
        for i in range(10):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                shape=shape,
                metadata={'image_key': f'test_{i}'}
            )

        preprocessor.compute_all(temp_cache_path)

        # Load and verify zeros (skipped due to min_bucket_size)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should be all zeros (skipped)
        assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
        assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)

    def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
        """
        Test adaptive k with multiple buckets of different sizes.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8
        )

        # Bucket 1: 10 samples (adaptive k=9)
        for i in range(10):
            latent = torch.randn(4, 16, 16, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                shape=(4, 16, 16),
                metadata={'image_key': f'small_{i}'}
            )

        # Bucket 2: 40 samples (full k=32)
        for i in range(40):
            latent = torch.randn(4, 32, 32, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=100 + i,
                shape=(4, 32, 32),
                metadata={'image_key': f'large_{i}'}
            )

        # Bucket 3: 5 samples (< min=8, skipped)
        for i in range(5):
            latent = torch.randn(4, 8, 8, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=200 + i,
                shape=(4, 8, 8),
                metadata={'image_key': f'tiny_{i}'}
            )

        preprocessor.compute_all(temp_cache_path)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')

        # Bucket 1: Should have CDC (non-zero)
        eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
        assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6)

        # Bucket 2: Should have CDC (non-zero)
        eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
        assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6)

        # Bucket 3: Should be skipped (zeros)
        eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
        assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6)
        assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6)

    def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
        """
        Test that adaptive k uses the full k_neighbors when the bucket is large enough.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=16,
            k_bandwidth=4,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8
        )

        # Add 50 samples (> k=16, should use full k=16)
        shape = (4, 16, 16)
        for i in range(50):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                shape=shape,
                metadata={'image_key': f'test_{i}'}
            )

        preprocessor.compute_all(temp_cache_path)

        # Load and verify CDC was computed
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should have non-zero eigenvalues
        assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
        # Eigenvalues should be non-negative
        assert (eigvals >= 0).all()

if __name__ == "__main__":
    pytest.main([__file__, "-v"])