Compare commits

...

3 Commits

Author SHA1 Message Date
rockerBOO
1f79115c6c Consolidate and simplify CDC test files
- Merged redundant test files
- Removed 'comprehensive' from file and docstring names
- Improved test organization and clarity
- Ensured all tests continue to pass
- Simplified test documentation
2025-10-11 17:48:08 -04:00
rockerBOO
8089cb6925 Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
2025-10-11 17:17:09 -04:00
rockerBOO
aa3a216106 Slight cleanup 2025-10-11 16:15:35 -04:00
14 changed files with 2158 additions and 126 deletions

View File

@@ -150,7 +150,7 @@ class CarreDuChampComputer:
centered = neighbor_points - m_star
weighted_centered = np.sqrt(weights_uniform)[:, None] * centered
# Move to GPU for SVD (100x speedup!)
# Move to GPU for SVD
weighted_centered_torch = torch.from_numpy(weighted_centered).to(
self.device, dtype=torch.float32
)
@@ -354,9 +354,11 @@ class LatentBatcher:
Dict mapping exact_shape -> list of samples with that shape
"""
batches = {}
shapes = set()
for sample in self.samples:
shape_key = sample.shape
shapes.add(shape_key)
# Group by exact shape only - no aspect ratio grouping or resizing
if shape_key not in batches:
@@ -364,6 +366,15 @@ class LatentBatcher:
batches[shape_key].append(sample)
# If more than one unique shape, log a warning
if len(shapes) > 1:
logger.warning(
"Dimension mismatch: %d unique shapes detected. "
"Shapes: %s. Using Gaussian fallback for these samples.",
len(shapes),
shapes
)
return batches
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
@@ -761,7 +772,7 @@ class GammaBDataset:
t = t.view(-1, 1)
# Early return for t=0 to avoid numerical errors
if torch.allclose(t, torch.zeros_like(t), atol=1e-8):
if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
return x.reshape(orig_shape)
# Check if CDC is disabled (all eigenvalues are zero)

View File

@@ -6,8 +6,6 @@ Verifies that adaptive k properly adjusts based on bucket sizes.
import pytest
import torch
import numpy as np
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset

View File

@@ -0,0 +1,183 @@
import torch
from typing import Union
class MockGammaBDataset:
"""
Mock implementation of GammaBDataset for testing gradient flow
"""
def __init__(self, *args, **kwargs):
"""
Simple initialization that doesn't require file loading
"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def compute_sigma_t_x(
self,
eigenvectors: torch.Tensor,
eigenvalues: torch.Tensor,
x: torch.Tensor,
t: Union[float, torch.Tensor]
) -> torch.Tensor:
"""
Simplified implementation of compute_sigma_t_x for testing
"""
# Store original shape to restore later
orig_shape = x.shape
# Flatten x if it's 4D
if x.dim() == 4:
B, C, H, W = x.shape
x = x.reshape(B, -1) # (B, C*H*W)
if not isinstance(t, torch.Tensor):
t = torch.tensor(t, device=x.device, dtype=x.dtype)
# Validate dimensions
assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"
# Early return for t=0 with gradient preservation
if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
return x.reshape(orig_shape)
# Compute Σ_t @ x
# V^T x
Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
# sqrt(λ) * V^T x
sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
# V @ (sqrt(λ) * V^T x)
gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
# Interpolate between original and noisy latent
result = (1 - t) * x + t * gamma_sqrt_x
# Restore original shape
result = result.reshape(orig_shape)
return result
class TestCDCAdvanced:
def setup_method(self):
"""Prepare consistent test environment"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def test_gradient_flow_preservation(self):
"""
Verify that gradient flow is preserved even for near-zero time steps
with learnable time embeddings
"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Create a learnable time embedding with small initial value
t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
# Generate mock latent and CDC components
batch_size, latent_dim = 4, 64
latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
# Create mock eigenvectors and eigenvalues
eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
eigenvalues = torch.rand(batch_size, 8, device=self.device)
# Ensure eigenvectors and eigenvalues are meaningful
eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
# Use the mock dataset
mock_dataset = MockGammaBDataset()
# Compute noisy latent with gradient tracking
noisy_latent = mock_dataset.compute_sigma_t_x(
eigenvectors,
eigenvalues,
latent,
t
)
# Compute a dummy loss to check gradient flow
loss = noisy_latent.sum()
# Compute gradients
loss.backward()
# Assertions to verify gradient flow
assert t.grad is not None, "Time embedding gradient should be computed"
assert latent.grad is not None, "Input latent gradient should be computed"
# Check gradient magnitudes are non-zero
t_grad_magnitude = torch.abs(t.grad).sum()
latent_grad_magnitude = torch.abs(latent.grad).sum()
assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"
# Optional: Print gradient details for debugging
print(f"Time embedding gradient magnitude: {t_grad_magnitude}")
print(f"Latent gradient magnitude: {latent_grad_magnitude}")
def test_gradient_flow_with_different_time_steps(self):
"""
Verify gradient flow across different time step values
"""
# Test time steps
time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]
for time_val in time_steps:
# Create a learnable time embedding
t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
# Generate mock latent and CDC components
batch_size, latent_dim = 4, 64
latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
# Create mock eigenvectors and eigenvalues
eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
eigenvalues = torch.rand(batch_size, 8, device=self.device)
# Ensure eigenvectors and eigenvalues are meaningful
eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
# Use the mock dataset
mock_dataset = MockGammaBDataset()
# Compute noisy latent with gradient tracking
noisy_latent = mock_dataset.compute_sigma_t_x(
eigenvectors,
eigenvalues,
latent,
t
)
# Compute a dummy loss to check gradient flow
loss = noisy_latent.sum()
# Compute gradients
loss.backward()
# Assertions to verify gradient flow
t_grad_magnitude = torch.abs(t.grad).sum()
latent_grad_magnitude = torch.abs(latent.grad).sum()
assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"
# Reset gradients for next iteration
if t.grad is not None:
t.grad.zero_()
if latent.grad is not None:
latent.grad.zero_()
def pytest_configure(config):
"""
Add custom markers for CDC-FM tests
"""
config.addinivalue_line(
"markers",
"gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
)

View File

@@ -0,0 +1,146 @@
"""
Test CDC-FM dimension handling and fallback mechanisms.
This module tests the behavior of the CDC Flow Matching implementation
when encountering latents with different dimensions.
"""
import torch
import logging
import tempfile
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestDimensionHandling:
def setup_method(self):
"""Prepare consistent test environment"""
self.logger = logging.getLogger(__name__)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def test_mixed_dimension_fallback(self):
"""
Verify that preprocessor falls back to standard noise for mixed-dimension batches
"""
# Prepare preprocessor with debug mode
preprocessor = CDCPreprocessor(debug=True)
# Different-sized latents (3D: channels, height, width)
latents = [
torch.randn(3, 32, 64), # First latent: 3x32x64
torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension)
]
# Use a mock handler to capture log messages
from library.cdc_fm import logger
log_messages = []
class LogCapture(logging.Handler):
def emit(self, record):
log_messages.append(record.getMessage())
# Temporarily add a capture handler
capture_handler = LogCapture()
logger.addHandler(capture_handler)
try:
# Try adding mixed-dimension latents
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
for i, latent in enumerate(latents):
preprocessor.add_latent(
latent,
global_idx=i,
metadata={'image_key': f'test_mixed_image_{i}'}
)
try:
cdc_path = preprocessor.compute_all(tmp_file.name)
except ValueError as e:
# If implementation raises ValueError, that's acceptable
assert "Dimension mismatch" in str(e)
return
# Check for dimension-related log messages
dimension_warnings = [
msg for msg in log_messages
if "dimension mismatch" in msg.lower()
]
assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
# Load results and verify fallback
dataset = GammaBDataset(cdc_path)
finally:
# Remove the capture handler
logger.removeHandler(capture_handler)
# Check metadata about samples with/without CDC
assert dataset.num_samples == len(latents), "All samples should be processed"
def test_adaptive_k_with_dimension_constraints(self):
"""
Test adaptive k-neighbors behavior with dimension constraints
"""
# Prepare preprocessor with adaptive k and small bucket size
preprocessor = CDCPreprocessor(
adaptive_k=True,
min_bucket_size=5,
debug=True
)
# Generate latents with similar but not identical dimensions
base_latent = torch.randn(3, 32, 64)
similar_latents = [
base_latent,
torch.randn(3, 32, 65), # Slightly different dimension
torch.randn(3, 32, 66) # Another slightly different dimension
]
# Use a mock handler to capture log messages
from library.cdc_fm import logger
log_messages = []
class LogCapture(logging.Handler):
def emit(self, record):
log_messages.append(record.getMessage())
# Temporarily add a capture handler
capture_handler = LogCapture()
logger.addHandler(capture_handler)
try:
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
# Add similar latents
for i, latent in enumerate(similar_latents):
preprocessor.add_latent(
latent,
global_idx=i,
metadata={'image_key': f'test_adaptive_k_image_{i}'}
)
cdc_path = preprocessor.compute_all(tmp_file.name)
# Load results
dataset = GammaBDataset(cdc_path)
# Verify samples processed
assert dataset.num_samples == len(similar_latents), "All samples should be processed"
# Optional: Check warnings about dimension differences
dimension_warnings = [
msg for msg in log_messages
if "dimension" in msg.lower()
]
print(f"Dimension-related warnings: {dimension_warnings}")
finally:
# Remove the capture handler
logger.removeHandler(capture_handler)
def pytest_configure(config):
"""
Configure custom markers for dimension handling tests
"""
config.addinivalue_line(
"markers",
"dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
)

View File

@@ -0,0 +1,310 @@
"""
Comprehensive CDC Dimension Handling and Warning Tests
This module tests:
1. Dimension mismatch detection and fallback mechanisms
2. Warning throttling for shape mismatches
3. Adaptive k-neighbors behavior with dimension constraints
"""
import pytest
import torch
import logging
import tempfile
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
class TestDimensionHandlingAndWarnings:
"""
Comprehensive testing of dimension handling, noise injection, and warning systems
"""
@pytest.fixture(autouse=True)
def clear_warned_samples(self):
"""Clear the warned samples set before each test"""
_cdc_warned_samples.clear()
yield
_cdc_warned_samples.clear()
def test_mixed_dimension_fallback(self):
"""
Verify that preprocessor falls back to standard noise for mixed-dimension batches
"""
# Prepare preprocessor with debug mode
preprocessor = CDCPreprocessor(debug=True)
# Different-sized latents (3D: channels, height, width)
latents = [
torch.randn(3, 32, 64), # First latent: 3x32x64
torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension)
]
# Use a mock handler to capture log messages
from library.cdc_fm import logger
log_messages = []
class LogCapture(logging.Handler):
def emit(self, record):
log_messages.append(record.getMessage())
# Temporarily add a capture handler
capture_handler = LogCapture()
logger.addHandler(capture_handler)
try:
# Try adding mixed-dimension latents
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
for i, latent in enumerate(latents):
preprocessor.add_latent(
latent,
global_idx=i,
metadata={'image_key': f'test_mixed_image_{i}'}
)
try:
cdc_path = preprocessor.compute_all(tmp_file.name)
except ValueError as e:
# If implementation raises ValueError, that's acceptable
assert "Dimension mismatch" in str(e)
return
# Check for dimension-related log messages
dimension_warnings = [
msg for msg in log_messages
if "dimension mismatch" in msg.lower()
]
assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
# Load results and verify fallback
dataset = GammaBDataset(cdc_path)
finally:
# Remove the capture handler
logger.removeHandler(capture_handler)
# Check metadata about samples with/without CDC
assert dataset.num_samples == len(latents), "All samples should be processed"
def test_adaptive_k_with_dimension_constraints(self):
"""
Test adaptive k-neighbors behavior with dimension constraints
"""
# Prepare preprocessor with adaptive k and small bucket size
preprocessor = CDCPreprocessor(
adaptive_k=True,
min_bucket_size=5,
debug=True
)
# Generate latents with similar but not identical dimensions
base_latent = torch.randn(3, 32, 64)
similar_latents = [
base_latent,
torch.randn(3, 32, 65), # Slightly different dimension
torch.randn(3, 32, 66) # Another slightly different dimension
]
# Use a mock handler to capture log messages
from library.cdc_fm import logger
log_messages = []
class LogCapture(logging.Handler):
def emit(self, record):
log_messages.append(record.getMessage())
# Temporarily add a capture handler
capture_handler = LogCapture()
logger.addHandler(capture_handler)
try:
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
# Add similar latents
for i, latent in enumerate(similar_latents):
preprocessor.add_latent(
latent,
global_idx=i,
metadata={'image_key': f'test_adaptive_k_image_{i}'}
)
cdc_path = preprocessor.compute_all(tmp_file.name)
# Load results
dataset = GammaBDataset(cdc_path)
# Verify samples processed
assert dataset.num_samples == len(similar_latents), "All samples should be processed"
# Optional: Check warnings about dimension differences
dimension_warnings = [
msg for msg in log_messages
if "dimension" in msg.lower()
]
print(f"Dimension-related warnings: {dimension_warnings}")
finally:
# Remove the capture handler
logger.removeHandler(capture_handler)
def test_warning_only_logged_once_per_sample(self, caplog):
"""
Test that shape mismatch warning is only logged once per sample.
Even if the same sample appears in multiple batches, only warn once.
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create cache with one specific shape
preprocessed_shape = (16, 32, 32)
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
for i in range(10):
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
cdc_path = preprocessor.compute_all(save_path=tmp_file.name)
dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
# Use different shape at runtime to trigger mismatch
runtime_shape = (16, 64, 64)
timesteps = torch.tensor([100.0], dtype=torch.float32)
image_keys = ['test_image_0'] # Same sample
# First call - should warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise1,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have exactly one warning
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 1, "First call should produce exactly one warning"
assert "CDC shape mismatch" in warnings[0].message
# Second call with same sample - should NOT warn
with caplog.at_level(logging.WARNING):
caplog.clear()
noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise2,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have NO warnings
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Second call with same sample should not warn"
def test_different_samples_each_get_one_warning(self, caplog):
"""
Test that different samples each get their own warning.
Each unique sample should be warned about once.
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create cache with specific shape
preprocessed_shape = (16, 32, 32)
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
for i in range(10):
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
cdc_path = preprocessor.compute_all(save_path=tmp_file.name)
dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
runtime_shape = (16, 64, 64)
timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
# First batch: samples 0, 1, 2
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have 3 warnings (one per sample)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 3, "Should warn for each of the 3 samples"
# Second batch: same samples 0, 1, 2
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have NO warnings (already warned)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 0, "Should not warn again for same samples"
# Third batch: new samples 3, 4
with caplog.at_level(logging.WARNING):
caplog.clear()
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
image_keys = ['test_image_3', 'test_image_4']
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Should have 2 warnings (new samples)
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
assert len(warnings) == 2, "Should warn for each of the 2 new samples"
def pytest_configure(config):
"""
Configure custom markers for dimension handling and warning tests
"""
config.addinivalue_line(
"markers",
"dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
)
config.addinivalue_line(
"markers",
"warning_throttling: mark test for CDC-FM warning suppression"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,164 @@
"""
Tests using realistic high-dimensional data to catch scaling bugs.
This test uses realistic VAE-like latents to ensure eigenvalue normalization
works correctly on real-world data.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestRealisticDataScaling:
"""Test eigenvalue scaling with realistic high-dimensional data"""
def test_high_dimensional_latents_not_saturated(self, tmp_path):
"""
Verify that high-dimensional realistic latents don't saturate eigenvalues.
This test simulates real FLUX training data:
- High dimension (16×64×64 = 65536)
- Varied content (different variance in different regions)
- Realistic magnitude (VAE output scale)
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create 20 samples with realistic varied structure
for i in range(20):
# High-dimensional latent like FLUX
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
# Create varied structure across the latent
# Different channels have different patterns (realistic for VAE)
for c in range(16):
# Some channels have gradients
if c < 4:
for h in range(64):
for w in range(64):
latent[c, h, w] = (h + w) / 128.0
# Some channels have patterns
elif c < 8:
for h in range(64):
for w in range(64):
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
# Some channels are more uniform
else:
latent[c, :, :] = c * 0.1
# Add per-sample variation (different "subjects")
latent = latent * (1.0 + i * 0.2)
# Add realistic VAE-like noise/variation
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are NOT all saturated at 1.0
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# Critical: eigenvalues should NOT all be 1.0
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
total = len(non_zero_eigvals)
percent_at_max = (at_max / total * 100) if total > 0 else 0
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
# FAIL if too many eigenvalues are saturated at 1.0
assert percent_at_max < 80, (
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
f"This indicates the normalization bug - raw eigenvalues are not being "
f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
)
# Should have good diversity
assert np.std(non_zero_eigvals) > 0.1, (
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
f"Should see diverse eigenvalues, not all the same value."
)
# Mean should be in reasonable range (not all 1.0)
mean_eigval = np.mean(non_zero_eigvals)
assert 0.05 < mean_eigval < 0.9, (
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
f"If mean ≈ 1.0, eigenvalues are saturated."
)
def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
"""
Test that datasets with more variance produce more diverse eigenvalues.
This ensures the normalization preserves relative information.
"""
# Create two preprocessors with different data variance
results = {}
for variance_scale in [0.5, 2.0]:
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
for i in range(15):
latent = torch.zeros(16, 32, 32, dtype=torch.float32)
# Create varied patterns
for c in range(16):
for h in range(32):
for w in range(32):
latent[c, h, w] = (
np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale
)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
eigvals = []
for i in range(15):
ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
eigvals.extend(ev[ev > 1e-6])
results[variance_scale] = {
'mean': np.mean(eigvals),
'std': np.std(eigvals),
'range': (np.min(eigvals), np.max(eigvals))
}
print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
# Both should have diversity (not saturated)
for scale in [0.5, 2.0]:
assert results[scale]['std'] > 0.1, (
f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,220 @@
"""
Comprehensive CDC Eigenvalue Validation Tests
These tests ensure that eigenvalue computation and scaling work correctly
across various scenarios, including:
- Scaling to reasonable ranges
- Handling high-dimensional data
- Preserving latent information
- Preventing computational artifacts
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestEigenvalueScaling:
"""Verify eigenvalue scaling and computational properties"""
def test_eigenvalues_in_correct_range(self, tmp_path):
"""
Verify eigenvalues are scaled to ~0.01-1.0 range, not millions.
Ensures:
- No numerical explosions
- Reasonable eigenvalue magnitudes
- Consistent scaling across samples
"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create deterministic latents with structured patterns
for i in range(10):
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
for h in range(8):
for w in range(8):
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
latent = latent + i * 0.1
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are in correct range
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(10):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# Critical assertions for eigenvalue scale
assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
# Check sqrt (used in noise) is reasonable
sqrt_max = np.sqrt(all_eigvals.max())
assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
print(f"✓ sqrt(max): {sqrt_max:.4f}")
def test_high_dimensional_latents_scaling(self, tmp_path):
"""
Verify scaling for high-dimensional realistic latents.
Key scenarios:
- High-dimensional data (16×64×64)
- Varied channel structures
- Realistic VAE-like data
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create 20 samples with realistic varied structure
for i in range(20):
# High-dimensional latent like FLUX
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
# Create varied structure across the latent
for c in range(16):
# Different patterns across channels
if c < 4:
for h in range(64):
for w in range(64):
latent[c, h, w] = (h + w) / 128.0
elif c < 8:
for h in range(64):
for w in range(64):
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
else:
latent[c, :, :] = c * 0.1
# Add per-sample variation
latent = latent * (1.0 + i * 0.2)
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are not all saturated
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
total = len(non_zero_eigvals)
percent_at_max = (at_max / total * 100) if total > 0 else 0
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
# Fail if too many eigenvalues are saturated
assert percent_at_max < 80, (
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
f"Raw eigenvalues not scaled before clamping. "
f"Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
)
# Should have good diversity
assert np.std(non_zero_eigvals) > 0.1, (
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
f"Should see diverse eigenvalues, not all the same."
)
# Mean should be in reasonable range
mean_eigval = np.mean(non_zero_eigvals)
assert 0.05 < mean_eigval < 0.9, (
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
f"If mean ≈ 1.0, eigenvalues are saturated."
)
def test_noise_magnitude_reasonable(self, tmp_path):
"""
Verify CDC noise has reasonable magnitude for training.
Ensures noise:
- Has similar scale to input latents
- Won't destabilize training
- Preserves input variance
"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
for i in range(10):
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_gamma_b.safetensors"
cdc_path = preprocessor.compute_all(save_path=output_path)
# Load and compute noise
gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
# Simulate training scenario with deterministic data
batch_size = 3
latents = torch.zeros(batch_size, 16, 4, 4)
for b in range(batch_size):
for c in range(16):
for h in range(4):
for w in range(4):
latents[b, c, h, w] = (b + c + h + w) / 24.0
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
# Check noise magnitude
noise_std = noise.std().item()
latent_std = latents.std().item()
# Noise should be similar magnitude to input latents (within 10x)
ratio = noise_std / latent_std
assert 0.1 < ratio < 10.0, (
f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
f"ratio {ratio:.2f} is too extreme. Will cause training instability."
)
# Simulated MSE loss should be reasonable
simulated_loss = torch.mean((noise - latents) ** 2).item()
assert simulated_loss < 100.0, (
f"Simulated MSE loss {simulated_loss:.2f} is too high. "
f"Should be O(0.1-1.0) for stable training."
)
print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -1,52 +1,209 @@
"""
Test gradient flow through CDC noise transformation.
CDC Gradient Flow Verification Tests
Ensures that gradients propagate correctly through both fast and slow paths.
This module provides testing of:
1. Mock dataset gradient preservation
2. Real dataset gradient flow
3. Various time steps and computation paths
4. Fallback and edge case scenarios
"""
import pytest
import torch
import tempfile
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class TestCDCGradientFlow:
"""Test gradient flow through CDC transformations"""
class MockGammaBDataset:
"""
Mock implementation of GammaBDataset for testing gradient flow
"""
def __init__(self, *args, **kwargs):
"""
Simple initialization that doesn't require file loading
"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@pytest.fixture
def cdc_cache(self, tmp_path):
"""Create a test CDC cache"""
def compute_sigma_t_x(
self,
eigenvectors: torch.Tensor,
eigenvalues: torch.Tensor,
x: torch.Tensor,
t: torch.Tensor
) -> torch.Tensor:
"""
Simplified implementation of compute_sigma_t_x for testing
"""
# Store original shape to restore later
orig_shape = x.shape
# Flatten x if it's 4D
if x.dim() == 4:
B, C, H, W = x.shape
x = x.reshape(B, -1) # (B, C*H*W)
# Validate dimensions
assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"
# Early return for t=0 with gradient preservation
if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
return x.reshape(orig_shape)
# Compute Σ_t @ x
# V^T x
Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
# sqrt(λ) * V^T x
sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
# V @ (sqrt(λ) * V^T x)
gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
# Interpolate between original and noisy latent
result = (1 - t) * x + t * gamma_sqrt_x
# Restore original shape
result = result.reshape(orig_shape)
return result
class TestCDCGradientFlow:
"""
Gradient flow testing for CDC noise transformations
"""
def setup_method(self):
"""Prepare consistent test environment"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def test_mock_gradient_flow_near_zero_time_step(self):
"""
Verify gradient flow preservation for near-zero time steps
using mock dataset with learnable time embeddings
"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Create a learnable time embedding with small initial value
t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
# Generate mock latent and CDC components
batch_size, latent_dim = 4, 64
latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
# Create mock eigenvectors and eigenvalues
eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
eigenvalues = torch.rand(batch_size, 8, device=self.device)
# Ensure eigenvectors and eigenvalues are meaningful
eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
# Use the mock dataset
mock_dataset = MockGammaBDataset()
# Compute noisy latent with gradient tracking
noisy_latent = mock_dataset.compute_sigma_t_x(
eigenvectors,
eigenvalues,
latent,
t
)
# Compute a dummy loss to check gradient flow
loss = noisy_latent.sum()
# Compute gradients
loss.backward()
# Assertions to verify gradient flow
assert t.grad is not None, "Time embedding gradient should be computed"
assert latent.grad is not None, "Input latent gradient should be computed"
# Check gradient magnitudes are non-zero
t_grad_magnitude = torch.abs(t.grad).sum()
latent_grad_magnitude = torch.abs(latent.grad).sum()
assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"
def test_gradient_flow_with_multiple_time_steps(self):
"""
Verify gradient flow across different time step values
"""
# Test time steps
time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]
for time_val in time_steps:
# Create a learnable time embedding
t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
# Generate mock latent and CDC components
batch_size, latent_dim = 4, 64
latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
# Create mock eigenvectors and eigenvalues
eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
eigenvalues = torch.rand(batch_size, 8, device=self.device)
# Ensure eigenvectors and eigenvalues are meaningful
eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
# Use the mock dataset
mock_dataset = MockGammaBDataset()
# Compute noisy latent with gradient tracking
noisy_latent = mock_dataset.compute_sigma_t_x(
eigenvectors,
eigenvalues,
latent,
t
)
# Compute a dummy loss to check gradient flow
loss = noisy_latent.sum()
# Compute gradients
loss.backward()
# Assertions to verify gradient flow
t_grad_magnitude = torch.abs(t.grad).sum()
latent_grad_magnitude = torch.abs(latent.grad).sum()
assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"
# Reset gradients for next iteration
t.grad.zero_() if t.grad is not None else None
latent.grad.zero_() if latent.grad is not None else None
def test_gradient_flow_with_real_dataset(self, tmp_path):
"""
Test gradient flow with real CDC dataset
"""
# Create cache with uniform shapes
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create samples with same shape for fast path testing
shape = (16, 32, 32)
for i in range(20):
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
cache_path = tmp_path / "test_gradient.safetensors"
preprocessor.compute_all(save_path=cache_path)
return cache_path
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
def test_gradient_flow_fast_path(self, cdc_cache):
"""
Test that gradients flow correctly through batch processing (fast path).
All samples have matching shapes, so CDC uses batch processing.
"""
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
batch_size = 4
shape = (16, 32, 32)
# Create input noise with requires_grad
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
# Prepare test noise
torch.manual_seed(42)
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
@@ -60,102 +217,23 @@ class TestCDCGradientFlow:
device="cpu"
)
# Ensure output requires grad
# Verify gradient flow
assert noise_out.requires_grad, "Output should require gradients"
# Compute a simple loss and backprop
loss = noise_out.sum()
loss.backward()
# Verify gradients were computed for input
assert noise.grad is not None, "Gradients should flow back to input noise"
assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
assert (noise.grad != 0).any(), "Gradients should not be all zeros"
def test_gradient_flow_slow_path_all_match(self, cdc_cache):
def test_gradient_flow_with_fallback(self, tmp_path):
"""
Test gradient flow when slow path is taken but all shapes match.
Test gradient flow when using Gaussian fallback (shape mismatch)
This tests the per-sample loop with CDC transformation.
"""
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
batch_size = 4
shape = (16, 32, 32)
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
# Apply transformation
noise_out = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Test gradient flow
loss = noise_out.sum()
loss.backward()
assert noise.grad is not None
assert not torch.isnan(noise.grad).any()
assert (noise.grad != 0).any()
def test_gradient_consistency_between_paths(self, tmp_path):
"""
Test that fast path and slow path produce similar gradients.
When all shapes match, both paths should give consistent results.
"""
# Create cache with uniform shapes
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
shape = (16, 32, 32)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
cache_path = tmp_path / "test_consistency.safetensors"
preprocessor.compute_all(save_path=cache_path)
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
# Same input for both tests
torch.manual_seed(42)
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
# Apply CDC (should use fast path)
noise_out = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Compute gradients
loss = noise_out.sum()
loss.backward()
# Both paths should produce valid gradients
assert noise.grad is not None
assert not torch.isnan(noise.grad).any()
def test_fallback_gradient_flow(self, tmp_path):
"""
Test gradient flow when using Gaussian fallback (shape mismatch).
Ensures that cloned tensors maintain gradient flow correctly.
Ensures that cloned tensors maintain gradient flow correctly
even when shape mismatch triggers Gaussian noise
"""
# Create cache with one shape
preprocessor = CDCPreprocessor(
@@ -167,7 +245,7 @@ class TestCDCGradientFlow:
metadata = {'image_key': 'test_image_0'}
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)
cache_path = tmp_path / "test_fallback.safetensors"
cache_path = tmp_path / "test_fallback_gradient.safetensors"
preprocessor.compute_all(save_path=cache_path)
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
@@ -178,7 +256,6 @@ class TestCDCGradientFlow:
image_keys = ['test_image_0']
# Apply transformation (should fallback to Gaussian for this sample)
# Note: This will log a warning but won't raise
noise_out = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
@@ -195,8 +272,26 @@ class TestCDCGradientFlow:
loss.backward()
assert noise.grad is not None, "Gradients should flow even in fallback case"
assert not torch.isnan(noise.grad).any()
assert not torch.isnan(noise.grad).any(), "Fallback gradients should not contain NaN"
def pytest_configure(config):
"""
Configure custom markers for CDC gradient flow tests
"""
config.addinivalue_line(
"markers",
"gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
)
config.addinivalue_line(
"markers",
"mock_dataset: mark test using mock dataset for simplified gradient testing"
)
config.addinivalue_line(
"markers",
"real_dataset: mark test using real dataset for comprehensive gradient testing"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
pytest.main([__file__, "-v", "-s"])

View File

@@ -4,7 +4,6 @@ Test comparing interpolation vs pad/truncate for CDC preprocessing.
This test quantifies the difference between the two approaches.
"""
import numpy as np
import pytest
import torch
import torch.nn.functional as F
@@ -89,16 +88,16 @@ class TestInterpolationComparison:
print("\n" + "=" * 60)
print("Reconstruction Error Comparison")
print("=" * 60)
print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
print(f" Interpolation error: {interp_error_small:.6f}")
print(f" Pad/truncate error: {pad_error_small:.6f}")
if pad_error_small > 0:
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
else:
print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
print(f" BUT the intermediate representation is corrupted with zeros!")
print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
print(" BUT the intermediate representation is corrupted with zeros!")
print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
print(f" Interpolation error: {interp_error_large:.6f}")
print(f" Pad/truncate error: {truncate_error_large:.6f}")
if truncate_error_large > 0:
@@ -151,7 +150,7 @@ class TestInterpolationComparison:
print("\n" + "=" * 60)
print("Spatial Structure Preservation")
print("=" * 60)
print(f"\nGradient smoothness (lower is smoother):")
print("\nGradient smoothness (lower is smoother):")
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")

View File

@@ -0,0 +1,412 @@
"""
Performance and Interpolation Tests for CDC Flow Matching
This module provides testing of:
1. Computational overhead
2. Noise injection properties
3. Interpolation vs. pad/truncate methods
4. Spatial structure preservation
"""
import pytest
import torch
import time
import tempfile
import numpy as np
import torch.nn.functional as F
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPerformanceAndInterpolation:
"""
Comprehensive performance testing for CDC Flow Matching
Covers computational efficiency, noise properties, and interpolation quality
"""
@pytest.fixture(params=[
(3, 32, 32), # Small latent: typical for compact representations
(3, 64, 64), # Medium latent: standard feature maps
(3, 128, 128) # Large latent: high-resolution feature spaces
])
def latent_sizes(self, request):
"""
Parametrized fixture generating test cases for different latent sizes.
Rationale:
- Tests robustness across various computational scales
- Ensures consistent behavior from compact to large representations
- Identifies potential dimensionality-related performance bottlenecks
"""
return request.param
def test_computational_overhead(self, latent_sizes):
"""
Measure computational overhead of CDC preprocessing across latent sizes.
Performance Verification Objectives:
1. Verify preprocessing time scales predictably with input dimensions
2. Ensure adaptive k-neighbors works efficiently
3. Validate computational overhead remains within acceptable bounds
Performance Metrics:
- Total preprocessing time
- Per-sample processing time
- Computational complexity indicators
"""
# Tuned preprocessing configuration
preprocessor = CDCPreprocessor(
k_neighbors=256, # Comprehensive neighborhood exploration
d_cdc=8, # Geometric embedding dimensionality
debug=True, # Enable detailed performance logging
adaptive_k=True # Dynamic neighborhood size adjustment
)
# Set a fixed random seed for reproducibility
torch.manual_seed(42) # Consistent random generation
# Generate representative latent batch
batch_size = 32
latents = torch.randn(batch_size, *latent_sizes)
# Precision timing of preprocessing
start_time = time.perf_counter()
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
# Add latents with traceable metadata
for i, latent in enumerate(latents):
preprocessor.add_latent(
latent,
global_idx=i,
metadata={'image_key': f'perf_test_image_{i}'}
)
# Compute CDC results
cdc_path = preprocessor.compute_all(tmp_file.name)
# Calculate precise preprocessing metrics
end_time = time.perf_counter()
preprocessing_time = end_time - start_time
per_sample_time = preprocessing_time / batch_size
# Performance reporting and assertions
input_volume = np.prod(latent_sizes)
time_complexity_indicator = preprocessing_time / input_volume
print(f"\nPerformance Breakdown:")
print(f" Latent Size: {latent_sizes}")
print(f" Total Samples: {batch_size}")
print(f" Input Volume: {input_volume}")
print(f" Total Time: {preprocessing_time:.4f} seconds")
print(f" Per Sample Time: {per_sample_time:.6f} seconds")
print(f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel")
# Adaptive thresholds based on input dimensions
max_total_time = 10.0 # Base threshold
max_per_sample_time = 2.0 # Per-sample time threshold (more lenient)
# Different time complexity thresholds for different latent sizes
max_time_complexity = (
1e-2 if np.prod(latent_sizes) <= 3072 else # Smaller latents
1e-4 # Standard latents
)
# Performance assertions with informative error messages
assert preprocessing_time < max_total_time, (
f"Total preprocessing time exceeded threshold!\n"
f" Latent Size: {latent_sizes}\n"
f" Total Time: {preprocessing_time:.4f} seconds\n"
f" Threshold: {max_total_time} seconds"
)
assert per_sample_time < max_per_sample_time, (
f"Per-sample processing time exceeded threshold!\n"
f" Latent Size: {latent_sizes}\n"
f" Per Sample Time: {per_sample_time:.6f} seconds\n"
f" Threshold: {max_per_sample_time} seconds"
)
# More adaptable time complexity check
assert time_complexity_indicator < max_time_complexity, (
f"Time complexity scaling exceeded expectations!\n"
f" Latent Size: {latent_sizes}\n"
f" Input Volume: {input_volume}\n"
f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n"
f" Threshold: {max_time_complexity} seconds/voxel"
)
def test_noise_distribution(self, latent_sizes):
"""
Verify CDC noise injection quality and properties.
Based on test plan objectives:
1. CDC noise is actually being generated (not all Gaussian fallback)
2. Eigenvalues are valid (non-negative, bounded)
3. CDC components are finite and usable for noise generation
"""
preprocessor = CDCPreprocessor(
k_neighbors=16, # Reduced to match batch size
d_cdc=8,
gamma=1.0,
debug=True,
adaptive_k=True
)
# Set a fixed random seed for reproducibility
torch.manual_seed(42)
# Generate batch of latents
batch_size = 32
latents = torch.randn(batch_size, *latent_sizes)
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
# Add latents with metadata
for i, latent in enumerate(latents):
preprocessor.add_latent(
latent,
global_idx=i,
metadata={'image_key': f'noise_dist_image_{i}'}
)
# Compute CDC results
cdc_path = preprocessor.compute_all(tmp_file.name)
# Analyze noise properties
dataset = GammaBDataset(cdc_path)
# Track samples that used CDC vs Gaussian fallback
cdc_samples = 0
gaussian_samples = 0
eigenvalue_stats = {
'min': float('inf'),
'max': float('-inf'),
'mean': 0.0,
'sum': 0.0
}
# Verify each sample's CDC components
for i in range(batch_size):
image_key = f'noise_dist_image_{i}'
# Get eigenvectors and eigenvalues
eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key])
# Skip zero eigenvectors (fallback case)
if torch.all(eigvecs[0] == 0):
gaussian_samples += 1
continue
# Get the top d_cdc eigenvectors and eigenvalues
top_eigvecs = eigvecs[0] # (d_cdc, d)
top_eigvals = eigvals[0] # (d_cdc,)
# Basic validity checks
assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}"
assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}"
# Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM)
assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}"
assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}"
# Update statistics
eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item())
eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item())
eigenvalue_stats['sum'] += top_eigvals.sum().item()
cdc_samples += 1
# Compute mean eigenvalue across all CDC samples
if cdc_samples > 0:
eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * 8) # 8 = d_cdc
# Print final statistics
print(f"\nNoise Distribution Results for latent size {latent_sizes}:")
print(f" CDC samples: {cdc_samples}/{batch_size}")
print(f" Gaussian fallback: {gaussian_samples}/{batch_size}")
print(f" Eigenvalue min: {eigenvalue_stats['min']:.4f}")
print(f" Eigenvalue max: {eigenvalue_stats['max']:.4f}")
print(f" Eigenvalue mean: {eigenvalue_stats['mean']:.4f}")
# Assertions based on plan objectives
# 1. CDC noise should be generated for most samples
assert cdc_samples > 0, "No samples used CDC noise injection"
assert gaussian_samples < batch_size // 2, (
f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}"
)
# 2. Eigenvalues should be valid (non-negative and bounded)
assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative"
assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0"
# 3. Mean eigenvalue should be reasonable (not degenerate)
assert eigenvalue_stats['mean'] > 0.05, (
f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), "
"suggests degenerate CDC components"
)
def test_interpolation_reconstruction(self):
"""
Compare interpolation vs pad/truncate reconstruction methods for CDC.
"""
# Create test latents with different sizes - deterministic
latent_small = torch.zeros(16, 4, 4)
for c in range(16):
for h in range(4):
for w in range(4):
latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
latent_large = torch.zeros(16, 8, 8)
for c in range(16):
for h in range(8):
for w in range(8):
latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
target_h, target_w = 6, 6 # Median size
# Method 1: Interpolation
def interpolate_method(latent, target_h, target_w):
latent_input = latent.unsqueeze(0) # (1, C, H, W)
latent_resized = F.interpolate(
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
)
# Resize back
C, H, W = latent.shape
latent_reconstructed = F.interpolate(
latent_resized, size=(H, W), mode='bilinear', align_corners=False
)
error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
return relative_error
# Method 2: Pad/Truncate
def pad_truncate_method(latent, target_h, target_w):
C, H, W = latent.shape
latent_flat = latent.reshape(-1)
target_dim = C * target_h * target_w
current_dim = C * H * W
if current_dim == target_dim:
latent_resized_flat = latent_flat
elif current_dim > target_dim:
# Truncate
latent_resized_flat = latent_flat[:target_dim]
else:
# Pad
latent_resized_flat = torch.zeros(target_dim)
latent_resized_flat[:current_dim] = latent_flat
# Resize back
if current_dim == target_dim:
latent_reconstructed_flat = latent_resized_flat
elif current_dim > target_dim:
# Pad back
latent_reconstructed_flat = torch.zeros(current_dim)
latent_reconstructed_flat[:target_dim] = latent_resized_flat
else:
# Truncate back
latent_reconstructed_flat = latent_resized_flat[:current_dim]
latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
return relative_error
# Compare for small latent (needs padding)
interp_error_small = interpolate_method(latent_small, target_h, target_w)
pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
# Compare for large latent (needs truncation)
interp_error_large = interpolate_method(latent_large, target_h, target_w)
truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
print("\n" + "=" * 60)
print("Reconstruction Error Comparison")
print("=" * 60)
print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
print(f" Interpolation error: {interp_error_small:.6f}")
print(f" Pad/truncate error: {pad_error_small:.6f}")
if pad_error_small > 0:
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
else:
print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
print(" BUT the intermediate representation is corrupted with zeros!")
print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
print(f" Interpolation error: {interp_error_large:.6f}")
print(f" Pad/truncate error: {truncate_error_large:.6f}")
if truncate_error_large > 0:
print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
print("\nKey insight: For CDC, intermediate representation quality matters,")
print("not reconstruction error. Interpolation preserves spatial structure.")
# Verify interpolation errors are reasonable
assert interp_error_small < 1.0, "Interpolation should have reasonable error"
assert interp_error_large < 1.0, "Interpolation should have reasonable error"
def test_spatial_structure_preservation(self):
"""
Test that interpolation preserves spatial structure better than pad/truncate.
"""
# Create a latent with clear spatial pattern (gradient)
C, H, W = 16, 4, 4
latent = torch.zeros(C, H, W)
for i in range(H):
for j in range(W):
latent[:, i, j] = i * W + j # Gradient pattern
target_h, target_w = 6, 6
# Interpolation
latent_input = latent.unsqueeze(0)
latent_interp = F.interpolate(
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
).squeeze(0)
# Pad/truncate
latent_flat = latent.reshape(-1)
target_dim = C * target_h * target_w
latent_padded = torch.zeros(target_dim)
latent_padded[:len(latent_flat)] = latent_flat
latent_pad = latent_padded.reshape(C, target_h, target_w)
# Check gradient preservation
# For interpolation, adjacent pixels should have smooth gradients
grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
# For padding, there will be abrupt changes (gradient to zero)
grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
print("\n" + "=" * 60)
print("Spatial Structure Preservation")
print("=" * 60)
print("\nGradient smoothness (lower is smoother):")
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
# Padding introduces larger gradients due to abrupt zeros
assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
def pytest_configure(config):
"""
Configure performance benchmarking markers
"""
config.addinivalue_line(
"markers",
"performance: mark test to verify CDC-FM computational performance"
)
config.addinivalue_line(
"markers",
"noise_distribution: mark test to verify noise injection properties"
)
config.addinivalue_line(
"markers",
"interpolation: mark test to verify interpolation quality"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,260 @@
"""
CDC Preprocessor and Device Consistency Tests
This module provides testing of:
1. CDC Preprocessor functionality
2. Device consistency handling
3. GammaBDataset loading and usage
4. End-to-end CDC workflow verification
"""
import pytest
import logging
import torch
from pathlib import Path
from safetensors.torch import save_file
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class TestCDCPreprocessorIntegration:
"""
Comprehensive testing of CDC preprocessing and device handling
"""
def test_basic_preprocessor_workflow(self, tmp_path):
"""
Test basic CDC preprocessing with small dataset
"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
# Add 10 small latents
for i in range(10):
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
# Compute and save
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify file was created
assert Path(result_path).exists()
# Verify structure
with safe_open(str(result_path), framework="pt", device="cpu") as f:
assert f.get_tensor("metadata/num_samples").item() == 10
assert f.get_tensor("metadata/k_neighbors").item() == 5
assert f.get_tensor("metadata/d_cdc").item() == 4
# Check first sample
eigvecs = f.get_tensor("eigenvectors/test_image_0")
eigvals = f.get_tensor("eigenvalues/test_image_0")
assert eigvecs.shape[0] == 4 # d_cdc
assert eigvals.shape[0] == 4 # d_cdc
def test_preprocessor_with_different_shapes(self, tmp_path):
"""
Test CDC preprocessing with variable-size latents (bucketing)
"""
preprocessor = CDCPreprocessor(
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
)
# Add 5 latents of shape (16, 4, 4)
for i in range(5):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
# Add 5 latents of different shape (16, 8, 8)
for i in range(5, 10):
latent = torch.randn(16, 8, 8, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
# Compute and save
output_path = tmp_path / "test_gamma_b_multi.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify both shape groups were processed
with safe_open(str(result_path), framework="pt", device="cpu") as f:
# Check shapes are stored
shape_0 = f.get_tensor("shapes/test_image_0")
shape_5 = f.get_tensor("shapes/test_image_5")
assert tuple(shape_0.tolist()) == (16, 4, 4)
assert tuple(shape_5.tolist()) == (16, 8, 8)
class TestDeviceConsistency:
"""
Test device handling and consistency for CDC transformations
"""
def test_matching_devices_no_warning(self, tmp_path, caplog):
"""
Test that no warnings are emitted when devices match.
"""
# Create CDC cache on CPU
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
shape = (16, 32, 32)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
cache_path = tmp_path / "test_device.safetensors"
preprocessor.compute_all(save_path=cache_path)
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
image_keys = ['test_image_0', 'test_image_1']
with caplog.at_level(logging.WARNING):
caplog.clear()
_ = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# No device mismatch warnings
device_warnings = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()]
assert len(device_warnings) == 0, "Should not warn when devices match"
def test_device_mismatch_handling(self, tmp_path):
"""
Test that CDC transformation handles device mismatch gracefully
"""
# Create CDC cache on CPU
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
shape = (16, 32, 32)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
cache_path = tmp_path / "test_device_mismatch.safetensors"
preprocessor.compute_all(save_path=cache_path)
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
# Create noise and timesteps
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
image_keys = ['test_image_0', 'test_image_1']
# Perform CDC transformation
result = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=1000,
gamma_b_dataset=dataset,
image_keys=image_keys,
device="cpu"
)
# Verify output characteristics
assert result.shape == noise.shape
assert result.device == noise.device
assert result.requires_grad # Gradients should still work
assert not torch.isnan(result).any()
assert not torch.isinf(result).any()
# Verify gradients flow
loss = result.sum()
loss.backward()
assert noise.grad is not None
class TestCDCEndToEnd:
"""
End-to-end CDC workflow tests
"""
def test_full_preprocessing_usage_workflow(self, tmp_path):
"""
Test complete workflow: preprocess -> save -> load -> use
"""
# Step 1: Preprocess latents
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
num_samples = 10
for i in range(num_samples):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "cdc_gamma_b.safetensors"
cdc_path = preprocessor.compute_all(save_path=output_path)
# Step 2: Load with GammaBDataset
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
assert gamma_b_dataset.num_samples == num_samples
# Step 3: Use in mock training scenario
batch_size = 3
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
batch_t = torch.rand(batch_size)
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
# Verify output is reasonable
assert sigma_t_x.shape == batch_latents_flat.shape
assert not torch.isnan(sigma_t_x).any()
assert torch.isfinite(sigma_t_x).all()
# Verify that noise changes with different timesteps
sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size))
sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size))
# At t=0, should be close to x; at t=1, should be different
assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6)
assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1)
def pytest_configure(config):
"""
Configure custom markers for CDC tests
"""
config.addinivalue_line(
"markers",
"device_consistency: mark test to verify device handling in CDC transformations"
)
config.addinivalue_line(
"markers",
"preprocessor: mark test to verify CDC preprocessing workflow"
)
config.addinivalue_line(
"markers",
"end_to_end: mark test to verify full CDC workflow"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,237 @@
"""
Tests to validate the CDC rescaling recommendations from paper review.
These tests check:
1. Gamma parameter interaction with rescaling
2. Spatial adaptivity of eigenvalue scaling
3. Verification of fixed vs adaptive rescaling behavior
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestGammaRescalingInteraction:
"""Test that gamma parameter works correctly with eigenvalue rescaling"""
def test_gamma_scales_eigenvalues_correctly(self, tmp_path):
"""Verify gamma multiplier is applied correctly after rescaling"""
# Create two preprocessors with different gamma values
gamma_values = [0.5, 1.0, 2.0]
eigenvalue_results = {}
for gamma in gamma_values:
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu"
)
# Add identical deterministic data for all runs
for i in range(10):
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.1
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / f"test_gamma_{gamma}.safetensors"
preprocessor.compute_all(save_path=output_path)
# Extract eigenvalues
with safe_open(str(output_path), framework="pt", device="cpu") as f:
eigvals = f.get_tensor("eigenvalues/test_image_0").numpy()
eigenvalue_results[gamma] = eigvals
# With clamping to [1e-3, gamma*1.0], verify gamma changes the upper bound
# Gamma 0.5: max eigenvalue should be ~0.5
# Gamma 1.0: max eigenvalue should be ~1.0
# Gamma 2.0: max eigenvalue should be ~2.0
max_0p5 = np.max(eigenvalue_results[0.5])
max_1p0 = np.max(eigenvalue_results[1.0])
max_2p0 = np.max(eigenvalue_results[2.0])
assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}"
assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}"
assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}"
# All should have min of 1e-3 (clamp lower bound)
assert np.min(eigenvalue_results[0.5][eigenvalue_results[0.5] > 0]) >= 1e-3
assert np.min(eigenvalue_results[1.0][eigenvalue_results[1.0] > 0]) >= 1e-3
assert np.min(eigenvalue_results[2.0][eigenvalue_results[2.0] > 0]) >= 1e-3
print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}")
print(f"✓ Gamma 1.0 max: {max_1p0:.4f}")
print(f"✓ Gamma 2.0 max: {max_2p0:.4f}")
def test_large_gamma_maintains_reasonable_scale(self, tmp_path):
"""Verify that large gamma values don't cause eigenvalue explosion"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu"
)
for i in range(10):
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.15
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_large_gamma.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(10):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
max_eigval = np.max(all_eigvals)
mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6])
# With gamma=10.0 and target_scale=0.1, eigenvalues should be ~1.0
# But they should still be reasonable (not exploding)
assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma"
assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma"
print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}")
class TestSpatialAdaptivityOfRescaling:
"""Test spatial variation in eigenvalue scaling"""
def test_eigenvalues_vary_spatially(self, tmp_path):
"""Verify eigenvalues differ across spatially separated clusters"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
# Create two distinct clusters in latent space
# Cluster 1: Tight cluster (low variance) - deterministic spread
for i in range(10):
latent = torch.zeros(16, 4, 4)
# Small variation around 0
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h + w) / 100.0 + i * 0.01
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
# Cluster 2: Loose cluster (high variance) - deterministic spread
for i in range(10, 20):
latent = torch.ones(16, 4, 4) * 5.0
# Large variation around 5.0
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] += (c + h + w) / 10.0 + (i - 10) * 0.2
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_spatial_variation.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
# Get eigenvalues from both clusters
cluster1_eigvals = []
cluster2_eigvals = []
for i in range(10):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
cluster1_eigvals.append(np.max(eigvals))
for i in range(10, 20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
cluster2_eigvals.append(np.max(eigvals))
cluster1_mean = np.mean(cluster1_eigvals)
cluster2_mean = np.mean(cluster2_eigvals)
print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}")
print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}")
# With fixed target_scale rescaling, eigenvalues should be similar
# despite different local geometry
# This demonstrates the limitation of fixed rescaling
ratio = cluster2_mean / (cluster1_mean + 1e-10)
print(f"✓ Ratio (loose/tight): {ratio:.2f}")
# Both should be rescaled to similar magnitude (~0.1 due to target_scale)
assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range"
assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range"
class TestFixedVsAdaptiveRescaling:
"""Compare current fixed rescaling vs paper's adaptive approach"""
def test_current_rescaling_is_uniform(self, tmp_path):
"""Demonstrate that current rescaling produces uniform eigenvalue scales"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
# Create samples with varying local density - deterministic
for i in range(20):
latent = torch.zeros(16, 4, 4)
# Some samples clustered, some isolated
if i < 10:
# Dense cluster around origin
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h + w) / 40.0 + i * 0.05
else:
# Isolated points - larger offset
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h + w) / 40.0 + i * 2.0
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_uniform_rescaling.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
max_eigenvalues = []
for i in range(20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
vals = eigvals[eigvals > 1e-6]
if vals.size: # at least one valid eigen-value
max_eigenvalues.append(vals.max())
if not max_eigenvalues: # safeguard against empty list
pytest.skip("no valid eigen-values found")
max_eigenvalues = np.array(max_eigenvalues)
# Check coefficient of variation (std / mean)
cv = max_eigenvalues.std() / max_eigenvalues.mean()
print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f}, {np.max(max_eigenvalues):.4f}]")
print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}")
print(f"✓ Coefficient of variation: {cv:.4f}")
# With clamping, eigenvalues should have relatively low variation
assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping"
# Mean should be reasonable (clamped to [1e-3, gamma*1.0] = [1e-3, 1.0])
assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range"
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -5,10 +5,8 @@ These tests focus on CDC-FM specific functionality without importing
the full training infrastructure that has problematic dependencies.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
import torch
from safetensors.torch import save_file

View File

@@ -7,7 +7,6 @@ Ensures that duplicate warnings for the same sample are not logged repeatedly.
import pytest
import torch
import logging
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples