# Source file: Kohya-ss-sd-scripts/tests/library/test_cdc_standalone.py
# Snapshot: 2025-10-30 23:27:13 -04:00 (300 lines, 12 KiB, Python)

"""
Standalone tests for CDC-FM per-file caching.
These tests focus on the current CDC-FM per-file caching implementation
with hash-based cache validation.
"""
from pathlib import Path
import pytest
import torch
import numpy as np
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPreprocessor:
    """Test CDC preprocessing functionality with per-file caching."""

    @staticmethod
    def _add_random_latents(preprocessor, tmp_path, indices, latent_shape, size_tag, generator):
        """Register one random float32 latent per index with *preprocessor*.

        Each latent is drawn from *generator* (so test data is reproducible), is
        given the cache path ``<tmp_path>/test_image_{i}_{size_tag}_flux.npz`` and
        the metadata ``{'image_key': 'test_image_{i}'}``.
        """
        for i in indices:
            latent = torch.randn(latent_shape, dtype=torch.float32, generator=generator)  # C, H, W
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=str(tmp_path / f"test_image_{i}_{size_tag}_flux.npz"),
                shape=latent.shape,
                metadata={'image_key': f'test_image_{i}'},
            )

    def test_cdc_preprocessor_basic_workflow(self, tmp_path):
        """Test basic CDC preprocessing with a small single-shape dataset."""
        generator = torch.Generator().manual_seed(0)
        latent_shape = (16, 4, 4)
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
        )
        self._add_random_latents(preprocessor, tmp_path, range(10), latent_shape, "0004x0004", generator)

        # Compute and save (creates per-file CDC caches)
        files_saved = preprocessor.compute_all()
        assert files_saved == 10

        # Verify first CDC file structure
        cdc_path = Path(
            CDCPreprocessor.get_cdc_npz_path(
                str(tmp_path / "test_image_0_0004x0004_flux.npz"),
                preprocessor.config_hash,
                latent_shape=latent_shape,
            )
        )
        assert cdc_path.exists()

        # np.load on an .npz returns an NpzFile that keeps the underlying file
        # open until closed; use it as a context manager so the handle is released.
        with np.load(cdc_path) as data:
            assert data['k_neighbors'] == 5
            assert data['d_cdc'] == 4
            # Leading axis of both eigen-arrays is d_cdc
            assert data['eigenvectors'].shape[0] == 4
            assert data['eigenvalues'].shape[0] == 4

    def test_cdc_preprocessor_different_shapes(self, tmp_path):
        """Test CDC preprocessing with variable-size latents (bucketing)."""
        generator = torch.Generator().manual_seed(0)
        preprocessor = CDCPreprocessor(
            k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
        )
        # Two shape groups of 5 latents each: (16, 4, 4) and (16, 8, 8)
        self._add_random_latents(preprocessor, tmp_path, range(5), (16, 4, 4), "0004x0004", generator)
        self._add_random_latents(preprocessor, tmp_path, range(5, 10), (16, 8, 8), "0008x0008", generator)

        # Verify both shape groups were processed
        files_saved = preprocessor.compute_all()
        assert files_saved == 10

        # Check shapes are stored in individual files
        cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
        )
        cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
        )
        with np.load(cdc_path_0) as data_0, np.load(cdc_path_5) as data_5:
            assert tuple(data_0['shape']) == (16, 4, 4)
            assert tuple(data_5['shape']) == (16, 8, 8)
class TestGammaBDataset:
    """Exercise GammaBDataset loading and Γ_b^(1/2) retrieval from per-file CDC caches."""

    @pytest.fixture
    def sample_cdc_cache(self, tmp_path):
        """Write CDC caches for 20 random (16, 8, 8) latents into *tmp_path*.

        Returns a ``(tmp_path, latents_npz_paths, config_hash)`` tuple.
        """
        # 20 samples with k_neighbors=5 are enough for a test-sized k-NN graph;
        # adaptive k is enabled because the dataset is small.
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
            adaptive_k=True,
            min_bucket_size=5,
        )
        npz_paths = [str(tmp_path / f"test_{i}_0008x0008_flux.npz") for i in range(20)]
        for idx, npz_path in enumerate(npz_paths):
            sample = torch.randn(16, 8, 8, dtype=torch.float32)  # flattens to d = 16 * 8 * 8 = 1024
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz_path,
                shape=sample.shape,
                metadata={'image_key': f'test_{idx}'},
            )
        preprocessor.compute_all()
        return tmp_path, npz_paths, preprocessor.config_hash

    @staticmethod
    def _load_components(config_hash, paths):
        """Open a CPU GammaBDataset and fetch its (eigvecs, eigvals) for *paths*.

        Returns ``(dataset, eigvecs, eigvals)``.
        """
        dataset = GammaBDataset(device="cpu", config_hash=config_hash)
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=(16, 8, 8))
        return dataset, eigvecs, eigvals

    def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
        """A single cached sample loads with (batch, d_cdc) leading axes."""
        _, npz_paths, config_hash = sample_cdc_cache
        _, eigvecs, eigvals = self._load_components(config_hash, npz_paths[:1])

        assert eigvecs.shape[0] == 1  # batch size
        assert eigvecs.shape[1] == 4  # d_cdc
        assert eigvals.shape == (1, 4)  # (batch, d_cdc)

    def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
        """Γ_b^(1/2) components for samples 0, 2, 4 have the right shapes and positive eigenvalues."""
        _, npz_paths, config_hash = sample_cdc_cache
        _, eigenvectors, eigenvalues = self._load_components(config_hash, npz_paths[0:5:2])

        assert eigenvectors.shape[0] == 3  # batch
        assert eigenvectors.shape[1] == 4  # d_cdc
        assert eigenvalues.shape == (3, 4)  # (batch, d_cdc)
        assert torch.all(eigenvalues > 0)

    def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
        """compute_sigma_t_x is the identity when every timestep is 0."""
        _, npz_paths, config_hash = sample_cdc_cache
        flat_batch = torch.randn(3, 1024)  # (B, d) with d = 16 * 8 * 8
        timesteps = torch.zeros(3)

        dataset, eigvecs, eigvals = self._load_components(config_hash, npz_paths[:3])
        result = dataset.compute_sigma_t_x(eigvecs, eigvals, flat_batch, timesteps)

        assert torch.allclose(result, flat_batch, atol=1e-6)

    def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
        """compute_sigma_t_x preserves the (B, d) input shape."""
        _, npz_paths, config_hash = sample_cdc_cache
        flat_batch = torch.randn(2, 1024)  # (B, d)
        timesteps = torch.tensor([0.3, 0.7])

        dataset, eigvecs, eigvals = self._load_components(config_hash, npz_paths[1:4:2])
        result = dataset.compute_sigma_t_x(eigvecs, eigvals, flat_batch, timesteps)

        assert result.shape == flat_batch.shape

    def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
        """compute_sigma_t_x yields only finite values for random timesteps in [0, 1)."""
        _, npz_paths, config_hash = sample_cdc_cache
        flat_batch = torch.randn(3, 1024)  # (B, d)
        timesteps = torch.rand(3)

        dataset, eigvecs, eigvals = self._load_components(config_hash, npz_paths[0:5:2])
        result = dataset.compute_sigma_t_x(eigvecs, eigvals, flat_batch, timesteps)

        assert not torch.isnan(result).any()
        assert torch.isfinite(result).all()
class TestCDCEndToEnd:
    """End-to-end CDC workflow tests."""

    def test_full_preprocessing_and_usage_workflow(self, tmp_path):
        """Test complete workflow: preprocess -> save -> load -> use.

        All random inputs come from a locally seeded generator. The closing
        ``not allclose`` check at t=1 depends on that random data, so an
        unseeded run could not be reproduced if it ever failed.
        """
        generator = torch.Generator().manual_seed(0)
        latent_shape = (16, 4, 4)
        flat_dim = 16 * 4 * 4  # 256: size of one flattened latent

        # Step 1: Preprocess latents (one CDC cache file per latent)
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
        )
        num_samples = 10
        latents_npz_paths = []
        for i in range(num_samples):
            latent = torch.randn(latent_shape, dtype=torch.float32, generator=generator)
            latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
            latents_npz_paths.append(latents_npz_path)
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{i}'},
            )
        files_saved = preprocessor.compute_all()
        assert files_saved == num_samples

        # Step 2: Load with GammaBDataset
        gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)

        # Step 3: Use in mock training scenario
        batch_size = 3
        batch_latents_flat = torch.randn(batch_size, flat_dim, generator=generator)  # (B, d)
        batch_t = torch.rand(batch_size, generator=generator)
        paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
        eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu", latent_shape=latent_shape)

        # Compute geometry-aware noise at random timesteps
        sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
        assert sigma_t_x.shape == batch_latents_flat.shape
        assert torch.isfinite(sigma_t_x).all()  # rules out both NaN and Inf

        # The transform depends on t: identity at t=0, clearly different at t=1
        sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size))
        sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size))
        assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6)
        assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1)
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # test failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))