Fix CDC tests to new format and deprecate old tests

This commit is contained in:
rockerBOO
2025-10-18 14:35:49 -04:00
parent 83c17de61f
commit c820acee58
15 changed files with 318 additions and 2830 deletions

View File

@@ -1,228 +0,0 @@
"""
Test adaptive k_neighbors functionality in CDC-FM.
Verifies that adaptive k properly adjusts based on bucket sizes.
"""
import pytest
import torch
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestAdaptiveK:
    """Tests for the adaptive k_neighbors behavior of CDCPreprocessor.

    Fixed mode (adaptive_k=False) skips any bucket with fewer than
    k_neighbors samples, storing zeros (Gaussian fallback).  Adaptive mode
    lowers k to bucket_size - 1 for small buckets instead, but still skips
    buckets below min_bucket_size.
    """

    @pytest.fixture
    def temp_cache_path(self, tmp_path):
        """Return a temporary safetensors cache path for CDC output."""
        return tmp_path / "adaptive_k_test.safetensors"

    def _add_samples(self, preprocessor, count, shape, key_prefix, start_idx=0):
        """Add `count` random latents of `shape` to `preprocessor`.

        Image keys are '<key_prefix>_<i>' (i counting from 0) and global
        indices start at `start_idx`.  Factored out so the population loop
        is not repeated in every test below.
        """
        for i in range(count):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=start_idx + i,
                shape=shape,
                metadata={'image_key': f'{key_prefix}_{i}'}
            )

    def test_fixed_k_skips_small_buckets(self, temp_cache_path):
        """
        Test that fixed k mode skips buckets with < k_neighbors samples.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=False  # Fixed mode
        )
        # Add 10 samples (< k=32, should be skipped)
        self._add_samples(preprocessor, 10, (4, 16, 16), 'test')
        preprocessor.compute_all(temp_cache_path)
        # Load and verify zeros (Gaussian fallback)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
        # Should be all zeros (fallback)
        assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
        assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)

    def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
        """
        Test that adaptive k mode uses k=bucket_size-1 for small buckets.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8
        )
        # Add 20 samples (< k=32, should use k=19)
        self._add_samples(preprocessor, 20, (4, 16, 16), 'test')
        preprocessor.compute_all(temp_cache_path)
        # Load and verify non-zero (CDC computed)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
        # Should NOT be all zeros (CDC was computed)
        assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
        assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)

    def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
        """
        Test that adaptive k mode skips buckets below min_bucket_size.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=16
        )
        # Add 10 samples (< min_bucket_size=16, should be skipped)
        self._add_samples(preprocessor, 10, (4, 16, 16), 'test')
        preprocessor.compute_all(temp_cache_path)
        # Load and verify zeros (skipped due to min_bucket_size)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
        # Should be all zeros (skipped)
        assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
        assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)

    def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
        """
        Test adaptive k with multiple buckets of different sizes.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8
        )
        # Bucket 1: 10 samples (adaptive k=9)
        self._add_samples(preprocessor, 10, (4, 16, 16), 'small')
        # Bucket 2: 40 samples (full k=32)
        self._add_samples(preprocessor, 40, (4, 32, 32), 'large', start_idx=100)
        # Bucket 3: 5 samples (< min=8, skipped)
        self._add_samples(preprocessor, 5, (4, 8, 8), 'tiny', start_idx=200)
        preprocessor.compute_all(temp_cache_path)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        # Bucket 1: Should have CDC (non-zero)
        eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
        assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6)
        # Bucket 2: Should have CDC (non-zero)
        eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
        assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6)
        # Bucket 3: Should be skipped (zeros)
        eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
        assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6)
        assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6)

    def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
        """
        Test that adaptive k uses full k_neighbors when bucket is large enough.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=16,
            k_bandwidth=4,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8
        )
        # Add 50 samples (> k=16, should use full k=16)
        self._add_samples(preprocessor, 50, (4, 16, 16), 'test')
        preprocessor.compute_all(temp_cache_path)
        # Load and verify CDC was computed
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
        # Should have non-zero eigenvalues
        assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
        # Eigenvalues should be positive
        assert (eigvals >= 0).all()
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -1,132 +0,0 @@
"""
Test device consistency handling in CDC noise transformation.
Ensures that device mismatches are handled gracefully.
"""
import pytest
import torch
import logging
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class TestDeviceConsistency:
    """Test device consistency validation"""

    @pytest.fixture
    def cdc_cache(self, tmp_path):
        """Create a test CDC cache"""
        # Build a small CPU-only cache of 10 identically shaped latents so
        # the tests can look up gamma_b matrices by image key.
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        shape = (16, 32, 32)
        for i in range(10):
            latent = torch.randn(*shape, dtype=torch.float32)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
        cache_path = tmp_path / "test_device.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        return cache_path

    def test_matching_devices_no_warning(self, cdc_cache, caplog):
        """
        Test that no warnings are emitted when devices match.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        shape = (16, 32, 32)
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
        timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
        image_keys = ['test_image_0', 'test_image_1']
        with caplog.at_level(logging.WARNING):
            caplog.clear()
            _ = apply_cdc_noise_transformation(
                noise=noise,
                timesteps=timesteps,
                num_timesteps=1000,
                gamma_b_dataset=dataset,
                image_keys=image_keys,
                device="cpu"
            )
        # No device mismatch warnings (matched on message substring, so other
        # unrelated warnings do not fail the test).
        device_warnings = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()]
        assert len(device_warnings) == 0, "Should not warn when devices match"

    def test_device_mismatch_warning_and_transfer(self, cdc_cache, caplog):
        """
        Test that device mismatch is detected, warned, and handled.
        This simulates the case where noise is on one device but CDC matrices
        are requested for another device.
        """
        # NOTE(review): this test currently passes the SAME device string on
        # both sides, so it only exercises the happy path; a real mismatch
        # (e.g. "cuda" vs "cpu") is not reproducible on a CPU-only runner.
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        shape = (16, 32, 32)
        # Create noise on CPU
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
        timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
        image_keys = ['test_image_0', 'test_image_1']
        # But request CDC matrices for a different device string
        # (In practice this would be "cuda" vs "cpu", but we simulate with string comparison)
        with caplog.at_level(logging.WARNING):
            caplog.clear()
            # Use a different device specification to trigger the check
            # We'll use "cpu" vs "cpu:0" as an example of string mismatch
            result = apply_cdc_noise_transformation(
                noise=noise,
                timesteps=timesteps,
                num_timesteps=1000,
                gamma_b_dataset=dataset,
                image_keys=image_keys,
                device="cpu"  # Same actual device, consistent string
            )
        # Should complete without errors
        assert result is not None
        assert result.shape == noise.shape

    def test_transformation_works_after_device_transfer(self, cdc_cache):
        """
        Test that CDC transformation produces valid output even if devices differ.
        The function should handle device transfer gracefully.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        shape = (16, 32, 32)
        # requires_grad=True so we can verify autograd flows through the transform.
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
        image_keys = ['test_image_0', 'test_image_1']
        result = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu"
        )
        # Verify output is valid
        assert result.shape == noise.shape
        assert result.device == noise.device
        assert result.requires_grad  # Gradients should still work
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()
        # Verify gradients flow
        loss = result.sum()
        loss.backward()
        assert noise.grad is not None
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,146 +0,0 @@
"""
Test CDC-FM dimension handling and fallback mechanisms.
This module tests the behavior of the CDC Flow Matching implementation
when encountering latents with different dimensions.
"""
import torch
import logging
import tempfile
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestDimensionHandling:
    def setup_method(self):
        """Prepare consistent test environment"""
        self.logger = logging.getLogger(__name__)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_mixed_dimension_fallback(self):
        """
        Verify that preprocessor falls back to standard noise for mixed-dimension batches
        """
        # Prepare preprocessor with debug mode
        preprocessor = CDCPreprocessor(debug=True)
        # Different-sized latents (3D: channels, height, width)
        latents = [
            torch.randn(3, 32, 64),   # First latent: 3x32x64
            torch.randn(3, 32, 128),  # Second latent: 3x32x128 (different dimension)
        ]
        # Use a mock handler to capture log messages
        from library.cdc_fm import logger
        log_messages = []

        class LogCapture(logging.Handler):
            def emit(self, record):
                log_messages.append(record.getMessage())

        # Temporarily add a capture handler
        capture_handler = LogCapture()
        logger.addHandler(capture_handler)
        try:
            # Try adding mixed-dimension latents
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                for i, latent in enumerate(latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_mixed_image_{i}'}
                    )
                try:
                    cdc_path = preprocessor.compute_all(tmp_file.name)
                except ValueError as e:
                    # If implementation raises ValueError, that's acceptable
                    assert "Dimension mismatch" in str(e)
                    return
                # Check for dimension-related log messages
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension mismatch" in msg.lower()
                ]
                assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
                # Load results and verify fallback
                dataset = GammaBDataset(cdc_path)
        finally:
            # Remove the capture handler (also runs on the early `return` above)
            logger.removeHandler(capture_handler)
        # Check metadata about samples with/without CDC
        assert dataset.num_samples == len(latents), "All samples should be processed"

    def test_adaptive_k_with_dimension_constraints(self):
        """
        Test adaptive k-neighbors behavior with dimension constraints
        """
        # Prepare preprocessor with adaptive k and small bucket size
        preprocessor = CDCPreprocessor(
            adaptive_k=True,
            min_bucket_size=5,
            debug=True
        )
        # Generate latents with similar but not identical dimensions
        base_latent = torch.randn(3, 32, 64)
        similar_latents = [
            base_latent,
            torch.randn(3, 32, 65),  # Slightly different dimension
            torch.randn(3, 32, 66)   # Another slightly different dimension
        ]
        # Use a mock handler to capture log messages
        from library.cdc_fm import logger
        log_messages = []

        class LogCapture(logging.Handler):
            def emit(self, record):
                log_messages.append(record.getMessage())

        # Temporarily add a capture handler
        capture_handler = LogCapture()
        logger.addHandler(capture_handler)
        try:
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                # Add similar latents
                for i, latent in enumerate(similar_latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_adaptive_k_image_{i}'}
                    )
                cdc_path = preprocessor.compute_all(tmp_file.name)
                # Load results
                dataset = GammaBDataset(cdc_path)
                # Verify samples processed
                assert dataset.num_samples == len(similar_latents), "All samples should be processed"
                # Optional: Check warnings about dimension differences
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension" in msg.lower()
                ]
                print(f"Dimension-related warnings: {dimension_warnings}")
        finally:
            # Remove the capture handler
            logger.removeHandler(capture_handler)
def pytest_configure(config):
    """Register the custom pytest marker used by the dimension-handling tests."""
    marker = "dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
    config.addinivalue_line("markers", marker)

View File

@@ -1,310 +0,0 @@
"""
Comprehensive CDC Dimension Handling and Warning Tests
This module tests:
1. Dimension mismatch detection and fallback mechanisms
2. Warning throttling for shape mismatches
3. Adaptive k-neighbors behavior with dimension constraints
"""
import pytest
import torch
import logging
import tempfile
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
class TestDimensionHandlingAndWarnings:
    """
    Comprehensive testing of dimension handling, noise injection, and warning systems
    """

    @pytest.fixture(autouse=True)
    def clear_warned_samples(self):
        """Clear the warned samples set before each test"""
        # The warn-once bookkeeping is a module-level set; reset it on both
        # sides of every test so state cannot leak between tests.
        _cdc_warned_samples.clear()
        yield
        _cdc_warned_samples.clear()

    def test_mixed_dimension_fallback(self):
        """
        Verify that preprocessor falls back to standard noise for mixed-dimension batches
        """
        # Prepare preprocessor with debug mode
        preprocessor = CDCPreprocessor(debug=True)
        # Different-sized latents (3D: channels, height, width)
        latents = [
            torch.randn(3, 32, 64),   # First latent: 3x32x64
            torch.randn(3, 32, 128),  # Second latent: 3x32x128 (different dimension)
        ]
        # Use a mock handler to capture log messages
        from library.cdc_fm import logger
        log_messages = []

        class LogCapture(logging.Handler):
            def emit(self, record):
                log_messages.append(record.getMessage())

        # Temporarily add a capture handler
        capture_handler = LogCapture()
        logger.addHandler(capture_handler)
        try:
            # Try adding mixed-dimension latents
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                for i, latent in enumerate(latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_mixed_image_{i}'}
                    )
                try:
                    cdc_path = preprocessor.compute_all(tmp_file.name)
                except ValueError as e:
                    # If implementation raises ValueError, that's acceptable
                    assert "Dimension mismatch" in str(e)
                    return
                # Check for dimension-related log messages
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension mismatch" in msg.lower()
                ]
                assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
                # Load results and verify fallback
                dataset = GammaBDataset(cdc_path)
        finally:
            # Remove the capture handler (also runs on the early `return` above)
            logger.removeHandler(capture_handler)
        # Check metadata about samples with/without CDC
        assert dataset.num_samples == len(latents), "All samples should be processed"

    def test_adaptive_k_with_dimension_constraints(self):
        """
        Test adaptive k-neighbors behavior with dimension constraints
        """
        # Prepare preprocessor with adaptive k and small bucket size
        preprocessor = CDCPreprocessor(
            adaptive_k=True,
            min_bucket_size=5,
            debug=True
        )
        # Generate latents with similar but not identical dimensions
        base_latent = torch.randn(3, 32, 64)
        similar_latents = [
            base_latent,
            torch.randn(3, 32, 65),  # Slightly different dimension
            torch.randn(3, 32, 66)   # Another slightly different dimension
        ]
        # Use a mock handler to capture log messages
        from library.cdc_fm import logger
        log_messages = []

        class LogCapture(logging.Handler):
            def emit(self, record):
                log_messages.append(record.getMessage())

        # Temporarily add a capture handler
        capture_handler = LogCapture()
        logger.addHandler(capture_handler)
        try:
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                # Add similar latents
                for i, latent in enumerate(similar_latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_adaptive_k_image_{i}'}
                    )
                cdc_path = preprocessor.compute_all(tmp_file.name)
                # Load results
                dataset = GammaBDataset(cdc_path)
                # Verify samples processed
                assert dataset.num_samples == len(similar_latents), "All samples should be processed"
                # Optional: Check warnings about dimension differences
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension" in msg.lower()
                ]
                print(f"Dimension-related warnings: {dimension_warnings}")
        finally:
            # Remove the capture handler
            logger.removeHandler(capture_handler)

    def test_warning_only_logged_once_per_sample(self, caplog):
        """
        Test that shape mismatch warning is only logged once per sample.
        Even if the same sample appears in multiple batches, only warn once.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # Create cache with one specific shape
        preprocessed_shape = (16, 32, 32)
        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            for i in range(10):
                latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
                metadata = {'image_key': f'test_image_{i}'}
                preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
            cdc_path = preprocessor.compute_all(save_path=tmp_file.name)
            dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
            # Use different shape at runtime to trigger mismatch
            runtime_shape = (16, 64, 64)
            timesteps = torch.tensor([100.0], dtype=torch.float32)
            image_keys = ['test_image_0']  # Same sample
            # First call - should warn
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32)
                _ = apply_cdc_noise_transformation(
                    noise=noise1,
                    timesteps=timesteps,
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu"
                )
            # Should have exactly one warning
            warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
            assert len(warnings) == 1, "First call should produce exactly one warning"
            assert "CDC shape mismatch" in warnings[0].message
            # Second call with same sample - should NOT warn
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32)
                _ = apply_cdc_noise_transformation(
                    noise=noise2,
                    timesteps=timesteps,
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu"
                )
            # Should have NO warnings
            warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
            assert len(warnings) == 0, "Second call with same sample should not warn"

    def test_different_samples_each_get_one_warning(self, caplog):
        """
        Test that different samples each get their own warning.
        Each unique sample should be warned about once.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # Create cache with specific shape
        preprocessed_shape = (16, 32, 32)
        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            for i in range(10):
                latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
                metadata = {'image_key': f'test_image_{i}'}
                preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
            cdc_path = preprocessor.compute_all(save_path=tmp_file.name)
            dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
            runtime_shape = (16, 64, 64)
            timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
            # First batch: samples 0, 1, 2
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
                image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
                _ = apply_cdc_noise_transformation(
                    noise=noise,
                    timesteps=timesteps,
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu"
                )
            # Should have 3 warnings (one per sample)
            warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
            assert len(warnings) == 3, "Should warn for each of the 3 samples"
            # Second batch: same samples 0, 1, 2
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
                image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
                _ = apply_cdc_noise_transformation(
                    noise=noise,
                    timesteps=timesteps,
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu"
                )
            # Should have NO warnings (already warned)
            warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
            assert len(warnings) == 0, "Should not warn again for same samples"
            # Third batch: new samples 3, 4
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
                image_keys = ['test_image_3', 'test_image_4']
                timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
                _ = apply_cdc_noise_transformation(
                    noise=noise,
                    timesteps=timesteps,
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu"
                )
            # Should have 2 warnings (new samples)
            warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
            assert len(warnings) == 2, "Should warn for each of the 2 new samples"
def pytest_configure(config):
    """Register custom pytest markers for dimension-handling and warning tests."""
    for marker in (
        "dimension_handling: mark test for CDC-FM dimension mismatch scenarios",
        "warning_throttling: mark test for CDC-FM warning suppression",
    ):
        config.addinivalue_line("markers", marker)
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,164 +0,0 @@
"""
Tests using realistic high-dimensional data to catch scaling bugs.
This test uses realistic VAE-like latents to ensure eigenvalue normalization
works correctly on real-world data.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestRealisticDataScaling:
"""Test eigenvalue scaling with realistic high-dimensional data"""
def test_high_dimensional_latents_not_saturated(self, tmp_path):
"""
Verify that high-dimensional realistic latents don't saturate eigenvalues.
This test simulates real FLUX training data:
- High dimension (16×64×64 = 65536)
- Varied content (different variance in different regions)
- Realistic magnitude (VAE output scale)
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create 20 samples with realistic varied structure
for i in range(20):
# High-dimensional latent like FLUX
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
# Create varied structure across the latent
# Different channels have different patterns (realistic for VAE)
for c in range(16):
# Some channels have gradients
if c < 4:
for h in range(64):
for w in range(64):
latent[c, h, w] = (h + w) / 128.0
# Some channels have patterns
elif c < 8:
for h in range(64):
for w in range(64):
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
# Some channels are more uniform
else:
latent[c, :, :] = c * 0.1
# Add per-sample variation (different "subjects")
latent = latent * (1.0 + i * 0.2)
# Add realistic VAE-like noise/variation
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are NOT all saturated at 1.0
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# Critical: eigenvalues should NOT all be 1.0
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
total = len(non_zero_eigvals)
percent_at_max = (at_max / total * 100) if total > 0 else 0
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
# FAIL if too many eigenvalues are saturated at 1.0
assert percent_at_max < 80, (
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
f"This indicates the normalization bug - raw eigenvalues are not being "
f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
)
# Should have good diversity
assert np.std(non_zero_eigvals) > 0.1, (
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
f"Should see diverse eigenvalues, not all the same value."
)
# Mean should be in reasonable range (not all 1.0)
mean_eigval = np.mean(non_zero_eigvals)
assert 0.05 < mean_eigval < 0.9, (
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
f"If mean ≈ 1.0, eigenvalues are saturated."
)
def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
"""
Test that datasets with more variance produce more diverse eigenvalues.
This ensures the normalization preserves relative information.
"""
# Create two preprocessors with different data variance
results = {}
for variance_scale in [0.5, 2.0]:
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
for i in range(15):
latent = torch.zeros(16, 32, 32, dtype=torch.float32)
# Create varied patterns
for c in range(16):
for h in range(32):
for w in range(32):
latent[c, h, w] = (
np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale
)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
eigvals = []
for i in range(15):
ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
eigvals.extend(ev[ev > 1e-6])
results[variance_scale] = {
'mean': np.mean(eigvals),
'std': np.std(eigvals),
'range': (np.min(eigvals), np.max(eigvals))
}
print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
# Both should have diversity (not saturated)
for scale in [0.5, 2.0]:
assert results[scale]['std'] > 0.1, (
f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
)
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,252 +0,0 @@
"""
Tests to verify CDC eigenvalue scaling is correct.
These tests ensure eigenvalues are properly scaled to prevent training loss explosion.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestEigenvalueScaling:
"""Test that eigenvalues are properly scaled to reasonable ranges"""
    def test_eigenvalues_in_correct_range(self, tmp_path):
        """Verify eigenvalues are scaled to ~0.01-1.0 range, not millions"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # Add deterministic latents with structured patterns
        for i in range(10):
            # Create gradient pattern: values from 0 to 2.0 across spatial dims
            latent = torch.zeros(16, 8, 8, dtype=torch.float32)
            for h in range(8):
                for w in range(8):
                    latent[:, h, w] = (h * 8 + w) / 32.0  # Range [0, 2.0]
            # Add per-sample variation
            latent = latent + i * 0.1
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)
        # Verify eigenvalues are in correct range
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        all_eigvals = np.array(all_eigvals)
        # Filter out zero eigenvalues (from padding when k < d_cdc)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
        # Critical assertions for eigenvalue scale
        assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
        assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
        assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
        # Check sqrt (used in noise) is reasonable
        sqrt_max = np.sqrt(all_eigvals.max())
        assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
        print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ sqrt(max): {sqrt_max:.4f}")
    def test_eigenvalues_not_all_zero(self, tmp_path):
        """Ensure eigenvalues are not all zero (indicating computation failure)"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        for i in range(10):
            # Create deterministic pattern (unique value per channel/pixel, shifted per sample)
            latent = torch.zeros(16, 4, 4, dtype=torch.float32)
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        all_eigvals = np.array(all_eigvals)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
        # With clamping, eigenvalues will be in range [1e-3, gamma*1.0]
        # Check that we have some non-zero eigenvalues
        assert len(non_zero_eigvals) > 0, "All eigenvalues are zero - computation failed"
        # Check they're in the expected clamped range
        assert np.all(non_zero_eigvals >= 1e-3), f"Some eigenvalues below clamp min: {np.min(non_zero_eigvals)}"
        assert np.all(non_zero_eigvals <= 1.0), f"Some eigenvalues above clamp max: {np.max(non_zero_eigvals)}"
        print(f"\n✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
        print(f"✓ Range: [{np.min(non_zero_eigvals):.4f}, {np.max(non_zero_eigvals):.4f}]")
        print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
def test_fp16_storage_no_overflow(self, tmp_path):
    """Stored tensors must be fp16 and stay far below the fp16 max (65,504)."""
    pre = CDCPreprocessor(
        k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
    )
    for idx in range(10):
        # Deterministic pattern with a higher magnitude than the other tests.
        sample = torch.zeros(16, 8, 8, dtype=torch.float32)
        for row in range(8):
            for col in range(8):
                sample[:, row, col] = (row * 8 + col) / 16.0  # Range [0, 4.0]
        sample = sample + idx * 0.3
        pre.add_latent(
            latent=sample,
            global_idx=idx,
            shape=sample.shape,
            metadata={'image_key': f'test_image_{idx}'},
        )
    result_path = pre.compute_all(save_path=tmp_path / "test_gamma_b.safetensors")
    with safe_open(str(result_path), framework="pt", device="cpu") as f:
        eigvecs = f.get_tensor("eigenvectors/test_image_0")
        eigvals = f.get_tensor("eigenvalues/test_image_0")
        # Storage dtype must be half precision.
        assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
        assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
        # Values anywhere near the fp16 ceiling would indicate an overflow.
        FP16_MAX = 65504
        max_eigval = eigvals.max().item()
        assert max_eigval < 100, (
            f"Eigenvalue {max_eigval:.2e} is suspiciously large for fp16 storage. "
            f"May indicate overflow (fp16 max = {FP16_MAX})"
        )
        print(f"\n✓ Storage dtype: {eigvals.dtype}")
        print(f"✓ Max eigenvalue: {max_eigval:.4f} (safe for fp16)")
def test_latent_magnitude_preserved(self, tmp_path):
    """Verify latent magnitude is preserved (no unwanted normalization).

    Adds latents with a known, structured spread, runs the full
    preprocessing pass, then compares the std of what the preprocessor kept
    against the std of the originals.

    NOTE(review): reads the internal ``preprocessor.batcher.samples``
    attribute and assumes each sample exposes ``.latent`` — confirm this is
    part of the preprocessor's stable contract.
    """
    preprocessor = CDCPreprocessor(
        k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
    )
    # Store original latents with deterministic patterns
    original_latents = []
    for i in range(10):
        # Create structured pattern with known magnitude
        latent = torch.zeros(16, 4, 4, dtype=torch.float32)
        for c in range(16):
            for h in range(4):
                for w in range(4):
                    latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
        original_latents.append(latent.clone())  # clone: add_latent may keep a reference
        metadata = {'image_key': f'test_image_{i}'}
        preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
    # Compute original latent statistics
    orig_std = torch.stack(original_latents).std().item()
    output_path = tmp_path / "test_gamma_b.safetensors"
    preprocessor.compute_all(save_path=output_path)
    # The stored latents should preserve original magnitude
    stored_latents_std = np.std([s.latent for s in preprocessor.batcher.samples])
    # Should be similar to original (within 20% due to potential batching effects)
    assert 0.8 * orig_std < stored_latents_std < 1.2 * orig_std, (
        f"Stored latent std {stored_latents_std:.2f} differs too much from "
        f"original {orig_std:.2f}. Latent magnitude was not preserved."
    )
    print(f"\n✓ Original latent std: {orig_std:.2f}")
    print(f"✓ Stored latent std: {stored_latents_std:.2f}")
class TestTrainingLossScale:
    """Test that eigenvalues produce reasonable loss magnitudes"""

    def test_noise_magnitude_reasonable(self, tmp_path):
        """CDC noise should be on the same scale as the latents it perturbs."""
        from library.cdc_fm import GammaBDataset

        # Build a small deterministic CDC cache.
        pre = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        for idx in range(10):
            sample = torch.zeros(16, 4, 4, dtype=torch.float32)
            for ch in range(16):
                for row in range(4):
                    for col in range(4):
                        sample[ch, row, col] = (ch + row + col) / 20.0 + idx * 0.1
            pre.add_latent(
                latent=sample,
                global_idx=idx,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )
        cdc_path = pre.compute_all(save_path=tmp_path / "test_gamma_b.safetensors")

        # Load the cache and assemble a deterministic mini-batch.
        gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
        batch_size = 3
        latents = torch.zeros(batch_size, 16, 4, 4)
        for b in range(batch_size):
            for ch in range(16):
                for row in range(4):
                    for col in range(4):
                        latents[b, ch, row, col] = (b + ch + row + col) / 24.0
        t = torch.tensor([0.5, 0.7, 0.9])  # Different timesteps
        image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
        eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
        noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)

        # Noise should be similar magnitude to input latents (within 10x).
        noise_std = noise.std().item()
        latent_std = latents.std().item()
        ratio = noise_std / latent_std
        assert 0.1 < ratio < 10.0, (
            f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
            f"ratio {ratio:.2f} is too extreme. Will cause training instability."
        )
        # A stand-in MSE against the clean latents must land in a trainable range.
        simulated_loss = torch.mean((noise - latents) ** 2).item()
        assert simulated_loss < 100.0, (
            f"Simulated MSE loss {simulated_loss:.2f} is too high. "
            f"Should be O(0.1-1.0) for stable training."
        )
        print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
        print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
if __name__ == "__main__":
    # Allow running this test module directly; -s keeps the diagnostic prints visible.
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,220 +0,0 @@
"""
Comprehensive CDC Eigenvalue Validation Tests
These tests ensure that eigenvalue computation and scaling work correctly
across various scenarios, including:
- Scaling to reasonable ranges
- Handling high-dimensional data
- Preserving latent information
- Preventing computational artifacts
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestEigenvalueScaling:
    """Verify eigenvalue scaling and computational properties"""

    def test_eigenvalues_in_correct_range(self, tmp_path):
        """
        Verify eigenvalues are scaled to ~0.01-1.0 range, not millions.
        Ensures:
        - No numerical explosions
        - Reasonable eigenvalue magnitudes
        - Consistent scaling across samples
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # Create deterministic latents with structured patterns
        for i in range(10):
            latent = torch.zeros(16, 8, 8, dtype=torch.float32)
            for h in range(8):
                for w in range(8):
                    latent[:, h, w] = (h * 8 + w) / 32.0  # Range [0, 2.0]
            latent = latent + i * 0.1  # per-sample offset so neighbors differ
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)
        # Verify eigenvalues are in correct range
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        all_eigvals = np.array(all_eigvals)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
        # Critical assertions for eigenvalue scale
        assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
        assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
        assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
        # Check sqrt (used in noise) is reasonable
        sqrt_max = np.sqrt(all_eigvals.max())
        assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
        print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ sqrt(max): {sqrt_max:.4f}")

    def test_high_dimensional_latents_scaling(self, tmp_path):
        """
        Verify scaling for high-dimensional realistic latents.
        Key scenarios:
        - High-dimensional data (16×64×64)
        - Varied channel structures
        - Realistic VAE-like data
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # Create 20 samples with realistic varied structure
        for i in range(20):
            # High-dimensional latent like FLUX
            latent = torch.zeros(16, 64, 64, dtype=torch.float32)
            # Create varied structure across the latent
            for c in range(16):
                # Different patterns across channels
                if c < 4:
                    # Linear ramp channels
                    for h in range(64):
                        for w in range(64):
                            latent[c, h, w] = (h + w) / 128.0
                elif c < 8:
                    # Smooth periodic channels
                    for h in range(64):
                        for w in range(64):
                            latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
                else:
                    # Constant channels
                    latent[c, :, :] = c * 0.1
            # Add per-sample variation
            latent = latent * (1.0 + i * 0.2)
            latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_realistic_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)
        # Verify eigenvalues are not all saturated
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        all_eigvals = np.array(all_eigvals)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
        # Fraction of eigenvalues pinned at the clamp ceiling of 1.0
        at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
        total = len(non_zero_eigvals)
        percent_at_max = (at_max / total * 100) if total > 0 else 0
        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
        print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
        # Fail if too many eigenvalues are saturated
        assert percent_at_max < 80, (
            f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
            f"Raw eigenvalues not scaled before clamping. "
            f"Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
        )
        # Should have good diversity
        assert np.std(non_zero_eigvals) > 0.1, (
            f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
            f"Should see diverse eigenvalues, not all the same."
        )
        # Mean should be in reasonable range
        mean_eigval = np.mean(non_zero_eigvals)
        assert 0.05 < mean_eigval < 0.9, (
            f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
            f"If mean ≈ 1.0, eigenvalues are saturated."
        )

    def test_noise_magnitude_reasonable(self, tmp_path):
        """
        Verify CDC noise has reasonable magnitude for training.
        Ensures noise:
        - Has similar scale to input latents
        - Won't destabilize training
        - Preserves input variance
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        for i in range(10):
            latent = torch.zeros(16, 4, 4, dtype=torch.float32)
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_gamma_b.safetensors"
        cdc_path = preprocessor.compute_all(save_path=output_path)
        # Load and compute noise
        gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
        # Simulate training scenario with deterministic data
        batch_size = 3
        latents = torch.zeros(batch_size, 16, 4, 4)
        for b in range(batch_size):
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latents[b, c, h, w] = (b + c + h + w) / 24.0
        t = torch.tensor([0.5, 0.7, 0.9])  # Different timesteps
        image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
        eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
        noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
        # Check noise magnitude
        noise_std = noise.std().item()
        latent_std = latents.std().item()
        # Noise should be similar magnitude to input latents (within 10x)
        ratio = noise_std / latent_std
        assert 0.1 < ratio < 10.0, (
            f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
            f"ratio {ratio:.2f} is too extreme. Will cause training instability."
        )
        # Simulated MSE loss should be reasonable
        simulated_loss = torch.mean((noise - latents) ** 2).item()
        assert simulated_loss < 100.0, (
            f"Simulated MSE loss {simulated_loss:.2f} is too high. "
            f"Should be O(0.1-1.0) for stable training."
        )
        print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
        print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
if __name__ == "__main__":
    # Allow running this test module directly; -s keeps the diagnostic prints visible.
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,297 +0,0 @@
"""
CDC Gradient Flow Verification Tests
This module provides testing of:
1. Mock dataset gradient preservation
2. Real dataset gradient flow
3. Various time steps and computation paths
4. Fallback and edge case scenarios
"""
import pytest
import torch
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class MockGammaBDataset:
    """Stand-in for GammaBDataset used to exercise gradient flow in tests.

    Implements just enough of the real interface (``compute_sigma_t_x``)
    without requiring any file loading.
    """

    def __init__(self, *args, **kwargs):
        """Accept (and ignore) any constructor arguments; only pick a device."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """Blend x toward its low-rank image V diag(sqrt(λ)) Vᵀ x.

        Returns ``(1 - t) * x + t * Σ_b^{1/2} x`` reshaped to x's original
        shape. Accepts flat ``(B, D)`` or image-shaped ``(B, C, H, W)`` input.
        """
        orig_shape = x.shape
        # Flatten 4D image latents to (B, C*H*W) for the einsum contractions.
        if x.dim() == 4:
            x = x.reshape(x.shape[0], -1)
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"
        # t == 0 short-circuit; skipped when t carries grad so that autograd
        # still records t's participation in the graph.
        if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)
        # Project onto the eigenbasis, scale by sqrt(λ), and project back.
        projected = torch.einsum('bkd,bd->bk', eigenvectors, x)
        scaled = torch.sqrt(eigenvalues.clamp(min=1e-10)) * projected
        low_rank = torch.einsum('bkd,bk->bd', eigenvectors, scaled)
        # Interpolate between the original and the transformed latent.
        blended = (1 - t) * x + t * low_rank
        return blended.reshape(orig_shape)
class TestCDCGradientFlow:
    """
    Gradient flow testing for CDC noise transformations.

    Exercises three paths:
    - mock dataset with a near-zero, learnable timestep
    - mock dataset across a sweep of timestep values
    - the real GammaBDataset path, including the Gaussian fallback that a
      runtime shape mismatch triggers
    """

    def setup_method(self):
        """Prepare consistent test environment"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _make_mock_inputs(self, batch_size: int = 4, latent_dim: int = 64, n_components: int = 8):
        """Build a (latent, eigenvectors, eigenvalues) triple for mock tests.

        The latent requires grad so tests can assert gradients reach it; the
        eigenvectors are unit-normalized and eigenvalues clamped to a
        meaningful positive, bounded range.
        """
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
        eigenvectors = torch.randn(batch_size, n_components, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, n_components, device=self.device)
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
        return latent, eigenvectors, eigenvalues

    def test_mock_gradient_flow_near_zero_time_step(self):
        """
        Verify gradient flow preservation for near-zero time steps
        using mock dataset with learnable time embeddings
        """
        torch.manual_seed(42)  # reproducible latent/eigen generation
        # Learnable time embedding with a small (but non-zero) initial value
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
        latent, eigenvectors, eigenvalues = self._make_mock_inputs()
        mock_dataset = MockGammaBDataset()
        noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)
        # Dummy scalar loss so backward() populates gradients
        loss = noisy_latent.sum()
        loss.backward()
        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"
        # Gradient magnitudes must be strictly non-zero
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()
        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"

    def test_gradient_flow_with_multiple_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]
        for time_val in time_steps:
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
            latent, eigenvectors, eigenvalues = self._make_mock_inputs()
            mock_dataset = MockGammaBDataset()
            noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)
            loss = noisy_latent.sum()
            loss.backward()
            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()
            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"
            # Reset gradients for the next iteration. (Fixed: previously a
            # conditional *expression* was used purely for its side effect,
            # `t.grad.zero_() if t.grad is not None else None`; plain
            # statements are the idiomatic form.)
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()

    def test_gradient_flow_with_real_dataset(self, tmp_path):
        """
        Test gradient flow with real CDC dataset
        """
        # Build a small cache with uniformly shaped latents
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        shape = (16, 32, 32)
        for i in range(10):
            latent = torch.randn(*shape, dtype=torch.float32)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
        cache_path = tmp_path / "test_gradient.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
        # Noise requires grad so we can verify backpropagation end-to-end
        torch.manual_seed(42)
        noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
        image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
        # Apply CDC transformation
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu"
        )
        # Verify gradient flow: finite, non-zero gradients reach the input
        assert noise_out.requires_grad, "Output should require gradients"
        loss = noise_out.sum()
        loss.backward()
        assert noise.grad is not None, "Gradients should flow back to input noise"
        assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
        assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
        assert (noise.grad != 0).any(), "Gradients should not be all zeros"

    def test_gradient_flow_with_fallback(self, tmp_path):
        """
        Test gradient flow when using Gaussian fallback (shape mismatch)

        Ensures that cloned tensors maintain gradient flow correctly
        even when shape mismatch triggers Gaussian noise
        """
        # Cache a single latent with one shape...
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        preprocessed_shape = (16, 32, 32)
        latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
        metadata = {'image_key': 'test_image_0'}
        preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)
        cache_path = tmp_path / "test_fallback_gradient.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
        # ...then request a different shape at runtime to force the fallback
        runtime_shape = (16, 64, 64)
        noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0], dtype=torch.float32)
        image_keys = ['test_image_0']
        # Apply transformation (should fallback to Gaussian for this sample)
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu"
        )
        # Ensure gradients still flow through the fallback path
        assert noise_out.requires_grad, "Fallback output should require gradients"
        loss = noise_out.sum()
        loss.backward()
        assert noise.grad is not None, "Gradients should flow even in fallback case"
        assert not torch.isnan(noise.grad).any(), "Fallback gradients should not contain NaN"
def pytest_configure(config):
    """Register the custom markers used by the CDC gradient-flow tests."""
    marker_lines = (
        "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching",
        "mock_dataset: mark test using mock dataset for simplified gradient testing",
        "real_dataset: mark test using real dataset for comprehensive gradient testing",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
if __name__ == "__main__":
    # Allow running this test module directly; -s keeps the diagnostic prints visible.
    pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,157 @@
"""
Test CDC config hash generation and cache invalidation
"""
import pytest
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor
class TestCDCConfigHash:
    """CDC config hashes must change when the dataset or parameters change.

    The hash is what invalidates stale caches: identical configurations must
    collide, and any meaningful difference must produce a different hash.
    """

    @staticmethod
    def _build(tmp_path, *, k_neighbors=5, d_cdc=4, gamma=1.0, dirs=("dataset1",)):
        """Construct a CPU preprocessor with the given parameter overrides."""
        return CDCPreprocessor(
            k_neighbors=k_neighbors, k_bandwidth=3, d_cdc=d_cdc, gamma=gamma,
            device="cpu", dataset_dirs=[str(tmp_path / d) for d in dirs],
        )

    def test_same_config_produces_same_hash(self, tmp_path):
        """Identical configurations hash identically."""
        first = self._build(tmp_path)
        second = self._build(tmp_path)
        assert first.config_hash == second.config_hash

    def test_different_dataset_dirs_produce_different_hash(self, tmp_path):
        """Changing the dataset directory changes the hash."""
        first = self._build(tmp_path, dirs=("dataset1",))
        second = self._build(tmp_path, dirs=("dataset2",))
        assert first.config_hash != second.config_hash

    def test_different_k_neighbors_produces_different_hash(self, tmp_path):
        """Changing k_neighbors changes the hash."""
        first = self._build(tmp_path, k_neighbors=5)
        second = self._build(tmp_path, k_neighbors=10)
        assert first.config_hash != second.config_hash

    def test_different_d_cdc_produces_different_hash(self, tmp_path):
        """Changing d_cdc changes the hash."""
        first = self._build(tmp_path, d_cdc=4)
        second = self._build(tmp_path, d_cdc=8)
        assert first.config_hash != second.config_hash

    def test_different_gamma_produces_different_hash(self, tmp_path):
        """Changing gamma changes the hash."""
        first = self._build(tmp_path, gamma=1.0)
        second = self._build(tmp_path, gamma=2.0)
        assert first.config_hash != second.config_hash

    def test_multiple_dataset_dirs_order_independent(self, tmp_path):
        """Directory order must not matter (dirs are sorted before hashing)."""
        first = self._build(tmp_path, dirs=("dataset1", "dataset2"))
        second = self._build(tmp_path, dirs=("dataset2", "dataset1"))
        assert first.config_hash == second.config_hash

    def test_hash_length_is_8_chars(self, tmp_path):
        """Hash is exactly 8 hex characters."""
        digest = self._build(tmp_path).config_hash
        assert len(digest) == 8
        # Verify it's hex; raises ValueError otherwise
        int(digest, 16)

    def test_filename_includes_hash(self, tmp_path):
        """CDC npz filenames embed the config hash before the extension."""
        digest = self._build(tmp_path).config_hash
        latents_path = str(tmp_path / "image_0512x0768_flux.npz")
        cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, digest)
        # Should be: image_0512x0768_flux_cdc_<hash>.npz
        assert cdc_path == str(tmp_path / f"image_0512x0768_flux_cdc_{digest}.npz")

    def test_backward_compatibility_no_hash(self, tmp_path):
        """Without a hash the legacy ``*_cdc.npz`` name is produced."""
        latents_path = str(tmp_path / "image_0512x0768_flux.npz")
        cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None)
        # Should be: image_0512x0768_flux_cdc.npz (no hash suffix)
        assert cdc_path == str(tmp_path / "image_0512x0768_flux_cdc.npz")
if __name__ == "__main__":
    # Allow running this test module directly.
    pytest.main([__file__, "-v"])

View File

@@ -1,163 +0,0 @@
"""
Test comparing interpolation vs pad/truncate for CDC preprocessing.
This test quantifies the difference between the two approaches.
"""
import pytest
import torch
import torch.nn.functional as F
class TestInterpolationComparison:
    """Compare interpolation vs pad/truncate"""

    def test_intermediate_representation_quality(self):
        """Compare intermediate representation quality for CDC computation.

        Round-trips latents of two sizes through a common intermediate size —
        once via bilinear interpolation, once via flat pad/truncate — and
        compares the relative reconstruction errors of the two strategies.
        """
        # Create test latents with different sizes - deterministic
        latent_small = torch.zeros(16, 4, 4)
        for c in range(16):
            for h in range(4):
                for w in range(4):
                    latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
        latent_large = torch.zeros(16, 8, 8)
        for c in range(16):
            for h in range(8):
                for w in range(8):
                    latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
        target_h, target_w = 6, 6  # Median size

        # Method 1: Interpolation
        def interpolate_method(latent, target_h, target_w):
            # Resize to the target via bilinear interpolation, resize back,
            # and return mean absolute error relative to input magnitude.
            latent_input = latent.unsqueeze(0)  # (1, C, H, W)
            latent_resized = F.interpolate(
                latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
            )
            # Resize back
            C, H, W = latent.shape
            latent_reconstructed = F.interpolate(
                latent_resized, size=(H, W), mode='bilinear', align_corners=False
            )
            error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
            relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
            return relative_error

        # Method 2: Pad/Truncate
        def pad_truncate_method(latent, target_h, target_w):
            # Flatten, zero-pad or truncate to the target element count, then
            # invert the operation and return the relative reconstruction error.
            C, H, W = latent.shape
            latent_flat = latent.reshape(-1)
            target_dim = C * target_h * target_w
            current_dim = C * H * W
            if current_dim == target_dim:
                latent_resized_flat = latent_flat
            elif current_dim > target_dim:
                # Truncate
                latent_resized_flat = latent_flat[:target_dim]
            else:
                # Pad
                latent_resized_flat = torch.zeros(target_dim)
                latent_resized_flat[:current_dim] = latent_flat
            # Resize back
            if current_dim == target_dim:
                latent_reconstructed_flat = latent_resized_flat
            elif current_dim > target_dim:
                # Pad back
                latent_reconstructed_flat = torch.zeros(current_dim)
                latent_reconstructed_flat[:target_dim] = latent_resized_flat
            else:
                # Truncate back
                latent_reconstructed_flat = latent_resized_flat[:current_dim]
            latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
            error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
            relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
            return relative_error

        # Compare for small latent (needs padding)
        interp_error_small = interpolate_method(latent_small, target_h, target_w)
        pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
        # Compare for large latent (needs truncation)
        interp_error_large = interpolate_method(latent_large, target_h, target_w)
        truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
        print("\n" + "=" * 60)
        print("Reconstruction Error Comparison")
        print("=" * 60)
        print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
        print(f" Interpolation error: {interp_error_small:.6f}")
        print(f" Pad/truncate error: {pad_error_small:.6f}")
        if pad_error_small > 0:
            print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
        else:
            print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
            print(" BUT the intermediate representation is corrupted with zeros!")
        print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
        print(f" Interpolation error: {interp_error_large:.6f}")
        print(f" Pad/truncate error: {truncate_error_large:.6f}")
        if truncate_error_large > 0:
            print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
        # The key insight: Reconstruction error is NOT what matters for CDC!
        # What matters is the INTERMEDIATE representation quality used for geometry estimation.
        # Pad/truncate may have good reconstruction, but the intermediate is corrupted.
        print("\nKey insight: For CDC, intermediate representation quality matters,")
        print("not reconstruction error. Interpolation preserves spatial structure.")
        # Verify interpolation errors are reasonable
        assert interp_error_small < 1.0, "Interpolation should have reasonable error"
        assert interp_error_large < 1.0, "Interpolation should have reasonable error"

    def test_spatial_structure_preservation(self):
        """Test that interpolation preserves spatial structure better than pad/truncate"""
        # Create a latent with clear spatial pattern (gradient)
        C, H, W = 16, 4, 4
        latent = torch.zeros(C, H, W)
        for i in range(H):
            for j in range(W):
                latent[:, i, j] = i * W + j  # Gradient pattern
        target_h, target_w = 6, 6
        # Interpolation
        latent_input = latent.unsqueeze(0)
        latent_interp = F.interpolate(
            latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
        ).squeeze(0)
        # Pad/truncate
        latent_flat = latent.reshape(-1)
        target_dim = C * target_h * target_w
        latent_padded = torch.zeros(target_dim)
        latent_padded[:len(latent_flat)] = latent_flat
        latent_pad = latent_padded.reshape(C, target_h, target_w)
        # Check gradient preservation
        # For interpolation, adjacent pixels should have smooth gradients
        grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
        grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
        # For padding, there will be abrupt changes (gradient to zero)
        grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
        grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
        print("\n" + "=" * 60)
        print("Spatial Structure Preservation")
        print("=" * 60)
        print("\nGradient smoothness (lower is smoother):")
        print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
        print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
        # Padding introduces larger gradients due to abrupt zeros
        assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
        assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
if __name__ == "__main__":
    # Allow running this test module directly; -s keeps the diagnostic prints visible.
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,412 +0,0 @@
"""
Performance and Interpolation Tests for CDC Flow Matching
This module provides testing of:
1. Computational overhead
2. Noise injection properties
3. Interpolation vs. pad/truncate methods
4. Spatial structure preservation
"""
import pytest
import torch
import time
import tempfile
import numpy as np
import torch.nn.functional as F
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPerformanceAndInterpolation:
    """
    Comprehensive performance testing for CDC Flow Matching.

    Covers computational efficiency, noise properties, and interpolation quality.
    """

    @pytest.fixture(params=[
        (3, 32, 32),  # Small latent: typical for compact representations
        (3, 64, 64),  # Medium latent: standard feature maps
        (3, 128, 128)  # Large latent: high-resolution feature spaces
    ])
    def latent_sizes(self, request):
        """
        Parametrized fixture generating test cases for different latent sizes.

        Rationale:
        - Tests robustness across various computational scales
        - Ensures consistent behavior from compact to large representations
        - Identifies potential dimensionality-related performance bottlenecks
        """
        return request.param

    def test_computational_overhead(self, latent_sizes):
        """
        Measure computational overhead of CDC preprocessing across latent sizes.

        Performance Verification Objectives:
        1. Verify preprocessing time scales predictably with input dimensions
        2. Ensure adaptive k-neighbors works efficiently
        3. Validate computational overhead remains within acceptable bounds

        Performance Metrics:
        - Total preprocessing time
        - Per-sample processing time
        - Computational complexity indicators
        """
        # Tuned preprocessing configuration
        preprocessor = CDCPreprocessor(
            k_neighbors=256,  # Comprehensive neighborhood exploration
            d_cdc=8,  # Geometric embedding dimensionality
            debug=True,  # Enable detailed performance logging
            adaptive_k=True  # Dynamic neighborhood size adjustment
        )
        # Set a fixed random seed for reproducibility
        torch.manual_seed(42)  # Consistent random generation
        # Generate representative latent batch
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)
        # Precision timing of preprocessing
        start_time = time.perf_counter()
        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with traceable metadata
            # NOTE(review): unlike the rescaling tests in this commit, add_latent
            # is called here without shape=/latents_npz_path= — confirm the
            # CDCPreprocessor.add_latent signature accepts this positional form.
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'perf_test_image_{i}'}
                )
            # Compute CDC results (cdc_path is unused here; kept for parity with
            # test_noise_distribution below)
            cdc_path = preprocessor.compute_all(tmp_file.name)
            # Calculate precise preprocessing metrics
            end_time = time.perf_counter()
            preprocessing_time = end_time - start_time
            per_sample_time = preprocessing_time / batch_size
            # Performance reporting and assertions
            input_volume = np.prod(latent_sizes)
            time_complexity_indicator = preprocessing_time / input_volume
            print(f"\nPerformance Breakdown:")
            print(f" Latent Size: {latent_sizes}")
            print(f" Total Samples: {batch_size}")
            print(f" Input Volume: {input_volume}")
            print(f" Total Time: {preprocessing_time:.4f} seconds")
            print(f" Per Sample Time: {per_sample_time:.6f} seconds")
            print(f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel")
            # Adaptive thresholds based on input dimensions
            max_total_time = 10.0  # Base threshold
            max_per_sample_time = 2.0  # Per-sample time threshold (more lenient)
            # Different time complexity thresholds for different latent sizes
            # (3072 == 3*32*32, i.e. the smallest parametrized latent volume)
            max_time_complexity = (
                1e-2 if np.prod(latent_sizes) <= 3072 else  # Smaller latents
                1e-4  # Standard latents
            )
            # Performance assertions with informative error messages
            assert preprocessing_time < max_total_time, (
                f"Total preprocessing time exceeded threshold!\n"
                f" Latent Size: {latent_sizes}\n"
                f" Total Time: {preprocessing_time:.4f} seconds\n"
                f" Threshold: {max_total_time} seconds"
            )
            assert per_sample_time < max_per_sample_time, (
                f"Per-sample processing time exceeded threshold!\n"
                f" Latent Size: {latent_sizes}\n"
                f" Per Sample Time: {per_sample_time:.6f} seconds\n"
                f" Threshold: {max_per_sample_time} seconds"
            )
            # More adaptable time complexity check
            assert time_complexity_indicator < max_time_complexity, (
                f"Time complexity scaling exceeded expectations!\n"
                f" Latent Size: {latent_sizes}\n"
                f" Input Volume: {input_volume}\n"
                f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n"
                f" Threshold: {max_time_complexity} seconds/voxel"
            )

    def test_noise_distribution(self, latent_sizes):
        """
        Verify CDC noise injection quality and properties.

        Based on test plan objectives:
        1. CDC noise is actually being generated (not all Gaussian fallback)
        2. Eigenvalues are valid (non-negative, bounded)
        3. CDC components are finite and usable for noise generation
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=16,  # Reduced to match batch size
            d_cdc=8,
            gamma=1.0,
            debug=True,
            adaptive_k=True
        )
        # Set a fixed random seed for reproducibility
        torch.manual_seed(42)
        # Generate batch of latents
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)
        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'noise_dist_image_{i}'}
                )
            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)
            # Analyze noise properties — must stay inside the `with` block so the
            # temporary cache file still exists when GammaBDataset reads it
            dataset = GammaBDataset(cdc_path)
            # Track samples that used CDC vs Gaussian fallback
            cdc_samples = 0
            gaussian_samples = 0
            eigenvalue_stats = {
                'min': float('inf'),
                'max': float('-inf'),
                'mean': 0.0,
                'sum': 0.0
            }
            # Verify each sample's CDC components
            for i in range(batch_size):
                image_key = f'noise_dist_image_{i}'
                # Get eigenvectors and eigenvalues
                eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key])
                # Skip zero eigenvectors (fallback case)
                if torch.all(eigvecs[0] == 0):
                    gaussian_samples += 1
                    continue
                # Get the top d_cdc eigenvectors and eigenvalues
                top_eigvecs = eigvecs[0]  # (d_cdc, d)
                top_eigvals = eigvals[0]  # (d_cdc,)
                # Basic validity checks
                assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}"
                assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}"
                # Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM)
                assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}"
                assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}"
                # Update statistics
                eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item())
                eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item())
                eigenvalue_stats['sum'] += top_eigvals.sum().item()
                cdc_samples += 1
            # Compute mean eigenvalue across all CDC samples
            if cdc_samples > 0:
                # NOTE(review): the literal 8 duplicates d_cdc from the
                # preprocessor config above — keep the two in sync.
                eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * 8)  # 8 = d_cdc
            # Print final statistics
            print(f"\nNoise Distribution Results for latent size {latent_sizes}:")
            print(f" CDC samples: {cdc_samples}/{batch_size}")
            print(f" Gaussian fallback: {gaussian_samples}/{batch_size}")
            print(f" Eigenvalue min: {eigenvalue_stats['min']:.4f}")
            print(f" Eigenvalue max: {eigenvalue_stats['max']:.4f}")
            print(f" Eigenvalue mean: {eigenvalue_stats['mean']:.4f}")
            # Assertions based on plan objectives
            # 1. CDC noise should be generated for most samples
            assert cdc_samples > 0, "No samples used CDC noise injection"
            assert gaussian_samples < batch_size // 2, (
                f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}"
            )
            # 2. Eigenvalues should be valid (non-negative and bounded)
            assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative"
            assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0"
            # 3. Mean eigenvalue should be reasonable (not degenerate)
            assert eigenvalue_stats['mean'] > 0.05, (
                f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), "
                "suggests degenerate CDC components"
            )

    def test_interpolation_reconstruction(self):
        """
        Compare interpolation vs pad/truncate reconstruction methods for CDC.
        """
        # Create test latents with different sizes - deterministic
        latent_small = torch.zeros(16, 4, 4)
        for c in range(16):
            for h in range(4):
                for w in range(4):
                    latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
        latent_large = torch.zeros(16, 8, 8)
        for c in range(16):
            for h in range(8):
                for w in range(8):
                    latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
        target_h, target_w = 6, 6  # Median size

        # Method 1: Interpolation
        def interpolate_method(latent, target_h, target_w):
            # Round-trips (C,H,W) -> (C,target_h,target_w) -> (C,H,W) via
            # bilinear resampling and returns the relative L1 error.
            latent_input = latent.unsqueeze(0)  # (1, C, H, W)
            latent_resized = F.interpolate(
                latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
            )
            # Resize back
            C, H, W = latent.shape
            latent_reconstructed = F.interpolate(
                latent_resized, size=(H, W), mode='bilinear', align_corners=False
            )
            error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
            relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
            return relative_error

        # Method 2: Pad/Truncate
        def pad_truncate_method(latent, target_h, target_w):
            # Round-trips by flattening and zero-padding / truncating to the
            # target element count; returns the relative L1 error.
            C, H, W = latent.shape
            latent_flat = latent.reshape(-1)
            target_dim = C * target_h * target_w
            current_dim = C * H * W
            if current_dim == target_dim:
                latent_resized_flat = latent_flat
            elif current_dim > target_dim:
                # Truncate
                latent_resized_flat = latent_flat[:target_dim]
            else:
                # Pad
                latent_resized_flat = torch.zeros(target_dim)
                latent_resized_flat[:current_dim] = latent_flat
            # Resize back
            if current_dim == target_dim:
                latent_reconstructed_flat = latent_resized_flat
            elif current_dim > target_dim:
                # Pad back
                latent_reconstructed_flat = torch.zeros(current_dim)
                latent_reconstructed_flat[:target_dim] = latent_resized_flat
            else:
                # Truncate back
                latent_reconstructed_flat = latent_resized_flat[:current_dim]
            latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
            error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
            relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
            return relative_error

        # Compare for small latent (needs padding)
        interp_error_small = interpolate_method(latent_small, target_h, target_w)
        pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
        # Compare for large latent (needs truncation)
        interp_error_large = interpolate_method(latent_large, target_h, target_w)
        truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
        print("\n" + "=" * 60)
        print("Reconstruction Error Comparison")
        print("=" * 60)
        print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
        print(f" Interpolation error: {interp_error_small:.6f}")
        print(f" Pad/truncate error: {pad_error_small:.6f}")
        if pad_error_small > 0:
            print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
        else:
            print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
            print(" BUT the intermediate representation is corrupted with zeros!")
        print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
        print(f" Interpolation error: {interp_error_large:.6f}")
        print(f" Pad/truncate error: {truncate_error_large:.6f}")
        if truncate_error_large > 0:
            print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
        print("\nKey insight: For CDC, intermediate representation quality matters,")
        print("not reconstruction error. Interpolation preserves spatial structure.")
        # Verify interpolation errors are reasonable
        assert interp_error_small < 1.0, "Interpolation should have reasonable error"
        assert interp_error_large < 1.0, "Interpolation should have reasonable error"

    def test_spatial_structure_preservation(self):
        """
        Test that interpolation preserves spatial structure better than pad/truncate.
        """
        # Create a latent with clear spatial pattern (gradient)
        C, H, W = 16, 4, 4
        latent = torch.zeros(C, H, W)
        for i in range(H):
            for j in range(W):
                latent[:, i, j] = i * W + j  # Gradient pattern
        target_h, target_w = 6, 6
        # Interpolation
        latent_input = latent.unsqueeze(0)
        latent_interp = F.interpolate(
            latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
        ).squeeze(0)
        # Pad/truncate
        latent_flat = latent.reshape(-1)
        target_dim = C * target_h * target_w
        latent_padded = torch.zeros(target_dim)
        latent_padded[:len(latent_flat)] = latent_flat
        latent_pad = latent_padded.reshape(C, target_h, target_w)
        # Check gradient preservation
        # For interpolation, adjacent pixels should have smooth gradients
        grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
        grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
        # For padding, there will be abrupt changes (gradient to zero)
        grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
        grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
        print("\n" + "=" * 60)
        print("Spatial Structure Preservation")
        print("=" * 60)
        print("\nGradient smoothness (lower is smoother):")
        print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
        print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
        # Padding introduces larger gradients due to abrupt zeros
        assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
        assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
def pytest_configure(config):
    """
    Register the custom pytest markers used by this CDC-FM test module.

    Registering markers up front keeps ``--strict-markers`` runs from
    rejecting the suite's custom annotations.
    """
    marker_definitions = (
        "performance: mark test to verify CDC-FM computational performance",
        "noise_distribution: mark test to verify noise injection properties",
        "interpolation: mark test to verify interpolation quality",
    )
    for definition in marker_definitions:
        config.addinivalue_line("markers", definition)
# Direct-execution entry point: forwards verbose/no-capture flags to pytest.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -29,7 +29,8 @@ class TestCDCPreprocessorIntegration:
Test basic CDC preprocessing with small dataset
"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
)
# Add 10 small latents
@@ -51,8 +52,9 @@ class TestCDCPreprocessorIntegration:
# Verify files were created
assert files_saved == 10
# Verify first CDC file structure
cdc_path = tmp_path / "test_image_0_0004x0004_flux_cdc.npz"
# Verify first CDC file structure (with config hash)
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
assert cdc_path.exists()
import numpy as np
@@ -73,7 +75,8 @@ class TestCDCPreprocessorIntegration:
Test CDC preprocessing with variable-size latents (bucketing)
"""
preprocessor = CDCPreprocessor(
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
)
# Add 5 latents of shape (16, 4, 4)
@@ -109,9 +112,15 @@ class TestCDCPreprocessorIntegration:
assert files_saved == 10
import numpy as np
# Check shapes are stored in individual files
data_0 = np.load(tmp_path / "test_image_0_0004x0004_flux_cdc.npz")
data_5 = np.load(tmp_path / "test_image_5_0008x0008_flux_cdc.npz")
# Check shapes are stored in individual files (with config hash)
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
)
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
)
data_0 = np.load(cdc_path_0)
data_5 = np.load(cdc_path_5)
assert tuple(data_0['shape']) == (16, 4, 4)
assert tuple(data_5['shape']) == (16, 8, 8)
@@ -128,7 +137,8 @@ class TestDeviceConsistency:
"""
# Create CDC cache on CPU
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
)
shape = (16, 32, 32)
@@ -148,7 +158,7 @@ class TestDeviceConsistency:
preprocessor.compute_all()
dataset = GammaBDataset(device="cpu")
dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
@@ -175,7 +185,8 @@ class TestDeviceConsistency:
"""
# Create CDC cache on CPU
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
)
shape = (16, 32, 32)
@@ -195,7 +206,7 @@ class TestDeviceConsistency:
preprocessor.compute_all()
dataset = GammaBDataset(device="cpu")
dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
# Create noise and timesteps
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
@@ -236,7 +247,8 @@ class TestCDCEndToEnd:
"""
# Step 1: Preprocess latents
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
)
num_samples = 10
@@ -257,8 +269,8 @@ class TestCDCEndToEnd:
files_saved = preprocessor.compute_all()
assert files_saved == num_samples
# Step 2: Load with GammaBDataset
gamma_b_dataset = GammaBDataset(device="cpu")
# Step 2: Load with GammaBDataset (use config hash)
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
# Step 3: Use in mock training scenario
batch_size = 3

View File

@@ -1,237 +0,0 @@
"""
Tests to validate the CDC rescaling recommendations from paper review.
These tests check:
1. Gamma parameter interaction with rescaling
2. Spatial adaptivity of eigenvalue scaling
3. Verification of fixed vs adaptive rescaling behavior
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestGammaRescalingInteraction:
    """Test that gamma parameter works correctly with eigenvalue rescaling."""

    def test_gamma_scales_eigenvalues_correctly(self, tmp_path):
        """Verify gamma multiplier is applied correctly after rescaling."""
        # Create two preprocessors with different gamma values
        gamma_values = [0.5, 1.0, 2.0]
        eigenvalue_results = {}
        for gamma in gamma_values:
            preprocessor = CDCPreprocessor(
                k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu"
            )
            # Add identical deterministic data for all runs
            for i in range(10):
                latent = torch.zeros(16, 4, 4, dtype=torch.float32)
                for c in range(16):
                    for h in range(4):
                        for w in range(4):
                            latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.1
                metadata = {'image_key': f'test_image_{i}'}
                preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
            output_path = tmp_path / f"test_gamma_{gamma}.safetensors"
            preprocessor.compute_all(save_path=output_path)
            # Extract eigenvalues for the first sample of this gamma run
            with safe_open(str(output_path), framework="pt", device="cpu") as f:
                eigvals = f.get_tensor("eigenvalues/test_image_0").numpy()
                eigenvalue_results[gamma] = eigvals
        # With clamping to [1e-3, gamma*1.0], verify gamma changes the upper bound
        # Gamma 0.5: max eigenvalue should be ~0.5
        # Gamma 1.0: max eigenvalue should be ~1.0
        # Gamma 2.0: max eigenvalue should be ~2.0
        max_0p5 = np.max(eigenvalue_results[0.5])
        max_1p0 = np.max(eigenvalue_results[1.0])
        max_2p0 = np.max(eigenvalue_results[2.0])
        assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}"
        assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}"
        assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}"
        # All should have min of 1e-3 (clamp lower bound); zeros are excluded first
        assert np.min(eigenvalue_results[0.5][eigenvalue_results[0.5] > 0]) >= 1e-3
        assert np.min(eigenvalue_results[1.0][eigenvalue_results[1.0] > 0]) >= 1e-3
        assert np.min(eigenvalue_results[2.0][eigenvalue_results[2.0] > 0]) >= 1e-3
        print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}")
        print(f"✓ Gamma 1.0 max: {max_1p0:.4f}")
        print(f"✓ Gamma 2.0 max: {max_2p0:.4f}")

    def test_large_gamma_maintains_reasonable_scale(self, tmp_path):
        """Verify that large gamma values don't cause eigenvalue explosion."""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu"
        )
        # Deterministic latents so results are stable across runs
        for i in range(10):
            latent = torch.zeros(16, 4, 4, dtype=torch.float32)
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] = (c + h + w) / 20.0 + i * 0.15
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_large_gamma.safetensors"
        preprocessor.compute_all(save_path=output_path)
        # Collect every sample's eigenvalues from the saved cache
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        max_eigval = np.max(all_eigvals)
        mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6])
        # With gamma=10.0 and target_scale=0.1, eigenvalues should be ~1.0
        # But they should still be reasonable (not exploding)
        assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma"
        assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma"
        print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}")
class TestSpatialAdaptivityOfRescaling:
    """Test spatial variation in eigenvalue scaling."""

    def test_eigenvalues_vary_spatially(self, tmp_path):
        """Verify eigenvalues differ across spatially separated clusters."""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        # Create two distinct clusters in latent space
        # Cluster 1: Tight cluster (low variance) - deterministic spread
        for i in range(10):
            latent = torch.zeros(16, 4, 4)
            # Small variation around 0
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] = (c + h + w) / 100.0 + i * 0.01
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        # Cluster 2: Loose cluster (high variance) - deterministic spread
        for i in range(10, 20):
            latent = torch.ones(16, 4, 4) * 5.0
            # Large variation around 5.0
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] += (c + h + w) / 10.0 + (i - 10) * 0.2
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_spatial_variation.safetensors"
        preprocessor.compute_all(save_path=output_path)
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            # Get eigenvalues from both clusters (per-sample max)
            cluster1_eigvals = []
            cluster2_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                cluster1_eigvals.append(np.max(eigvals))
            for i in range(10, 20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                cluster2_eigvals.append(np.max(eigvals))
            cluster1_mean = np.mean(cluster1_eigvals)
            cluster2_mean = np.mean(cluster2_eigvals)
            print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}")
            print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}")
            # With fixed target_scale rescaling, eigenvalues should be similar
            # despite different local geometry
            # This demonstrates the limitation of fixed rescaling
            ratio = cluster2_mean / (cluster1_mean + 1e-10)
            print(f"✓ Ratio (loose/tight): {ratio:.2f}")
            # Both should be rescaled to similar magnitude (~0.1 due to target_scale)
            assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range"
            assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range"
class TestFixedVsAdaptiveRescaling:
    """Compare current fixed rescaling vs paper's adaptive approach."""

    def test_current_rescaling_is_uniform(self, tmp_path):
        """Demonstrate that current rescaling produces uniform eigenvalue scales."""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        # Create samples with varying local density - deterministic
        for i in range(20):
            latent = torch.zeros(16, 4, 4)
            # Some samples clustered, some isolated
            if i < 10:
                # Dense cluster around origin
                for c in range(16):
                    for h in range(4):
                        for w in range(4):
                            latent[c, h, w] = (c + h + w) / 40.0 + i * 0.05
            else:
                # Isolated points - larger offset
                for c in range(16):
                    for h in range(4):
                        for w in range(4):
                            latent[c, h, w] = (c + h + w) / 40.0 + i * 2.0
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_uniform_rescaling.safetensors"
        preprocessor.compute_all(save_path=output_path)
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            max_eigenvalues = []
            for i in range(20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                vals = eigvals[eigvals > 1e-6]
                if vals.size:  # at least one valid eigen-value
                    max_eigenvalues.append(vals.max())
        if not max_eigenvalues:  # safeguard against empty list
            pytest.skip("no valid eigen-values found")
        max_eigenvalues = np.array(max_eigenvalues)
        # Check coefficient of variation (std / mean)
        cv = max_eigenvalues.std() / max_eigenvalues.mean()
        print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f}, {np.max(max_eigenvalues):.4f}]")
        print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}")
        print(f"✓ Coefficient of variation: {cv:.4f}")
        # With clamping, eigenvalues should have relatively low variation
        assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping"
        # Mean should be reasonable (clamped to [1e-3, gamma*1.0] = [1e-3, 1.0])
        assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range"
# Direct-execution entry point: forwards verbose/no-capture flags to pytest.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,132 +1,176 @@
"""
Standalone tests for CDC-FM integration.
Standalone tests for CDC-FM per-file caching.
These tests focus on CDC-FM specific functionality without importing
the full training infrastructure that has problematic dependencies.
These tests focus on the current CDC-FM per-file caching implementation
with hash-based cache validation.
"""
from pathlib import Path
import pytest
import torch
from safetensors.torch import save_file
import numpy as np
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPreprocessor:
"""Test CDC preprocessing functionality"""
"""Test CDC preprocessing functionality with per-file caching"""
def test_cdc_preprocessor_basic_workflow(self, tmp_path):
"""Test basic CDC preprocessing with small dataset"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)]
)
# Add 10 small latents
for i in range(10):
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Compute and save
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Compute and save (creates per-file CDC caches)
files_saved = preprocessor.compute_all()
# Verify file was created
assert Path(result_path).exists()
# Verify files were created
assert files_saved == 10
# Verify structure
from safetensors import safe_open
# Verify first CDC file structure
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
assert cdc_path.exists()
with safe_open(str(result_path), framework="pt", device="cpu") as f:
assert f.get_tensor("metadata/num_samples").item() == 10
assert f.get_tensor("metadata/k_neighbors").item() == 5
assert f.get_tensor("metadata/d_cdc").item() == 4
data = np.load(cdc_path)
assert data['k_neighbors'] == 5
assert data['d_cdc'] == 4
# Check first sample
eigvecs = f.get_tensor("eigenvectors/test_image_0")
eigvals = f.get_tensor("eigenvalues/test_image_0")
# Check eigenvectors and eigenvalues
eigvecs = data['eigenvectors']
eigvals = data['eigenvalues']
assert eigvecs.shape[0] == 4 # d_cdc
assert eigvals.shape[0] == 4 # d_cdc
assert eigvecs.shape[0] == 4 # d_cdc
assert eigvals.shape[0] == 4 # d_cdc
def test_cdc_preprocessor_different_shapes(self, tmp_path):
"""Test CDC preprocessing with variable-size latents (bucketing)"""
preprocessor = CDCPreprocessor(
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)]
)
# Add 5 latents of shape (16, 4, 4)
for i in range(5):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Add 5 latents of different shape (16, 8, 8)
for i in range(5, 10):
latent = torch.randn(16, 8, 8, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Compute and save
output_path = tmp_path / "test_gamma_b_multi.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
files_saved = preprocessor.compute_all()
# Verify both shape groups were processed
from safetensors import safe_open
assert files_saved == 10
with safe_open(str(result_path), framework="pt", device="cpu") as f:
# Check shapes are stored
shape_0 = f.get_tensor("shapes/test_image_0")
shape_5 = f.get_tensor("shapes/test_image_5")
# Check shapes are stored in individual files
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
)
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
)
assert tuple(shape_0.tolist()) == (16, 4, 4)
assert tuple(shape_5.tolist()) == (16, 8, 8)
data_0 = np.load(cdc_path_0)
data_5 = np.load(cdc_path_5)
assert tuple(data_0['shape']) == (16, 4, 4)
assert tuple(data_5['shape']) == (16, 8, 8)
class TestGammaBDataset:
"""Test GammaBDataset loading and retrieval"""
"""Test GammaBDataset loading and retrieval with per-file caching"""
@pytest.fixture
def sample_cdc_cache(self, tmp_path):
"""Create a sample CDC cache file for testing"""
cache_path = tmp_path / "test_gamma_b.safetensors"
"""Create sample CDC cache files for testing"""
# Use 20 samples to ensure proper k-NN computation
# (minimum 256 neighbors recommended, but 20 samples with k=5 is sufficient for testing)
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)],
adaptive_k=True, # Enable adaptive k for small dataset
min_bucket_size=5
)
# Create mock Γ_b data for 5 samples
tensors = {
"metadata/num_samples": torch.tensor([5]),
"metadata/k_neighbors": torch.tensor([10]),
"metadata/d_cdc": torch.tensor([4]),
"metadata/gamma": torch.tensor([1.0]),
}
# Create 20 samples
latents_npz_paths = []
for i in range(20):
latent = torch.randn(16, 8, 8, dtype=torch.float32) # C=16, d=1024 when flattened
latents_npz_path = str(tmp_path / f"test_{i}_0008x0008_flux.npz")
latents_npz_paths.append(latents_npz_path)
metadata = {'image_key': f'test_{i}'}
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
# Add shape and CDC data for each sample
for i in range(5):
tensors[f"shapes/{i}"] = torch.tensor([16, 8, 8]) # C, H, W
tensors[f"eigenvectors/{i}"] = torch.randn(4, 1024, dtype=torch.float32) # d_cdc x d
tensors[f"eigenvalues/{i}"] = torch.rand(4, dtype=torch.float32) + 0.1 # positive
save_file(tensors, str(cache_path))
return cache_path
preprocessor.compute_all()
return tmp_path, latents_npz_paths, preprocessor.config_hash
def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
    """Test that GammaBDataset loads CDC files correctly.

    New per-file format: the dataset is constructed from a config hash and
    samples are looked up by their latents .npz path, not by integer index.
    """
    tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
    gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
    # Get components for the first sample (lookup is by .npz path).
    eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu")
    # Check shapes
    assert eigvecs.shape[0] == 1  # batch size
    assert eigvecs.shape[1] == 4  # d_cdc
    assert eigvals.shape == (1, 4)  # batch, d_cdc
def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
"""Test retrieving Γ_b^(1/2) components"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
# Get Γ_b for indices [0, 2, 4]
indices = [0, 2, 4]
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(indices, device="cpu")
# Get Γ_b for paths [0, 2, 4]
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
# Check shapes
assert eigenvectors.shape == (3, 4, 1024) # (batch, d_cdc, d)
assert eigenvectors.shape[0] == 3 # batch
assert eigenvectors.shape[1] == 4 # d_cdc
assert eigenvalues.shape == (3, 4) # (batch, d_cdc)
# Check values are positive
@@ -134,14 +178,16 @@ class TestGammaBDataset:
def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
"""Test compute_sigma_t_x returns x unchanged at t=0"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
# Create test latents (batch of 3, matching d=1024 flattened)
x = torch.randn(3, 1024) # B, d (flattened)
t = torch.zeros(3) # t = 0 for all samples
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 1, 2], device="cpu")
paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]]
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
@@ -150,13 +196,15 @@ class TestGammaBDataset:
def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
"""Test compute_sigma_t_x returns correct shape"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
x = torch.randn(2, 1024) # B, d (flattened)
t = torch.tensor([0.3, 0.7])
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([1, 3], device="cpu")
paths = [latents_npz_paths[1], latents_npz_paths[3]]
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
@@ -165,13 +213,15 @@ class TestGammaBDataset:
def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
"""Test compute_sigma_t_x produces finite values"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
x = torch.randn(3, 1024) # B, d (flattened)
t = torch.rand(3) # Random timesteps in [0, 1]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 2, 4], device="cpu")
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
@@ -187,31 +237,39 @@ class TestCDCEndToEnd:
"""Test complete workflow: preprocess -> save -> load -> use"""
# Step 1: Preprocess latents
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
dataset_dirs=[str(tmp_path)]
)
num_samples = 10
latents_npz_paths = []
for i in range(num_samples):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
latents_npz_paths.append(latents_npz_path)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
preprocessor.add_latent(
latent=latent,
global_idx=i,
latents_npz_path=latents_npz_path,
shape=latent.shape,
metadata=metadata
)
output_path = tmp_path / "cdc_gamma_b.safetensors"
cdc_path = preprocessor.compute_all(save_path=output_path)
files_saved = preprocessor.compute_all()
assert files_saved == num_samples
# Step 2: Load with GammaBDataset
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
assert gamma_b_dataset.num_samples == num_samples
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
# Step 3: Use in mock training scenario
batch_size = 3
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
batch_t = torch.rand(batch_size)
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu")
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)

View File

@@ -1,178 +0,0 @@
"""
Test warning throttling for CDC shape mismatches.
Ensures that duplicate warnings for the same sample are not logged repeatedly.
"""
import pytest
import torch
import logging
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
class TestWarningThrottling:
    """Test that shape mismatch warnings are throttled"""

    @pytest.fixture(autouse=True)
    def clear_warned_samples(self):
        """Clear the warned samples set before each test"""
        # The throttle set is module-level state in flux_train_utils; reset
        # it around every test so tests do not leak warnings into each other.
        _cdc_warned_samples.clear()
        yield
        _cdc_warned_samples.clear()

    @pytest.fixture
    def cdc_cache(self, tmp_path):
        """Create a test CDC cache with one shape"""
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # All cached samples share one latent shape; tests trigger the
        # mismatch by feeding noise of a different shape at runtime.
        preprocessed_shape = (16, 32, 32)
        for sample_idx in range(10):
            preprocessor.add_latent(
                latent=torch.randn(*preprocessed_shape, dtype=torch.float32),
                global_idx=sample_idx,
                shape=preprocessed_shape,
                metadata={'image_key': f'test_image_{sample_idx}'},
            )
        cache_path = tmp_path / "test_throttle.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        return cache_path

    def test_warning_only_logged_once_per_sample(self, cdc_cache, caplog):
        """
        Test that shape mismatch warning is only logged once per sample.
        Even if the same sample appears in multiple batches, only warn once.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        # Use different shape at runtime to trigger mismatch
        runtime_shape = (16, 64, 64)
        timesteps = torch.tensor([100.0], dtype=torch.float32)
        image_keys = ['test_image_0']  # Same sample

        def run_once():
            # One pass through the CDC noise transform with a fresh noise
            # tensor; returns the WARNING records emitted by this call only.
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                _ = apply_cdc_noise_transformation(
                    noise=torch.randn(1, *runtime_shape, dtype=torch.float32),
                    timesteps=timesteps,
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu",
                )
            return [rec for rec in caplog.records if rec.levelname == "WARNING"]

        # First call - should warn exactly once.
        first_warnings = run_once()
        assert len(first_warnings) == 1, "First call should produce exactly one warning"
        assert "CDC shape mismatch" in first_warnings[0].message

        # Repeat calls with the same sample - should stay silent.
        assert len(run_once()) == 0, "Second call with same sample should not warn"
        assert len(run_once()) == 0, "Third call should still not warn"

    def test_different_samples_each_get_one_warning(self, cdc_cache, caplog):
        """
        Test that different samples each get their own warning.
        Each unique sample should be warned about once.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        runtime_shape = (16, 64, 64)

        def run_batch(image_keys):
            # Run one batch (one timestep per sample: 100, 200, ...) and
            # return the WARNING records emitted by this call only.
            batch_size = len(image_keys)
            with caplog.at_level(logging.WARNING):
                caplog.clear()
                _ = apply_cdc_noise_transformation(
                    noise=torch.randn(batch_size, *runtime_shape, dtype=torch.float32),
                    timesteps=torch.tensor(
                        [100.0 * (i + 1) for i in range(batch_size)],
                        dtype=torch.float32,
                    ),
                    num_timesteps=1000,
                    gamma_b_dataset=dataset,
                    image_keys=image_keys,
                    device="cpu",
                )
            return [rec for rec in caplog.records if rec.levelname == "WARNING"]

        # First batch: samples 0, 1, 2 — one warning per sample.
        warnings = run_batch(['test_image_0', 'test_image_1', 'test_image_2'])
        assert len(warnings) == 3, "Should warn for each of the 3 samples"

        # Second batch: same samples — already warned, so silent.
        warnings = run_batch(['test_image_0', 'test_image_1', 'test_image_2'])
        assert len(warnings) == 0, "Should not warn again for same samples"

        # Third batch: new samples 3, 4 — each gets its own warning.
        warnings = run_batch(['test_image_3', 'test_image_4'])
        assert len(warnings) == 2, "Should warn for each of the 2 new samples"
# Allow running this test module directly; delegates to pytest's CLI runner.
if __name__ == "__main__":
    pytest.main(["-v", "-s", __file__])