Improve dimension mismatch warning for CDC Flow Matching

- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
This commit is contained in:
rockerBOO
2025-10-11 17:17:09 -04:00
parent aa3a216106
commit 8089cb6925
11 changed files with 1014 additions and 13 deletions

View File

@@ -354,9 +354,11 @@ class LatentBatcher:
Dict mapping exact_shape -> list of samples with that shape
"""
batches = {}
shapes = set()
for sample in self.samples:
shape_key = sample.shape
shapes.add(shape_key)
# Group by exact shape only - no aspect ratio grouping or resizing
if shape_key not in batches:
@@ -364,6 +366,15 @@ class LatentBatcher:
batches[shape_key].append(sample)
# If more than one unique shape, log a warning
if len(shapes) > 1:
logger.warning(
"Dimension mismatch: %d unique shapes detected. "
"Shapes: %s. Using Gaussian fallback for these samples.",
len(shapes),
shapes
)
return batches
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:

View File

@@ -6,8 +6,6 @@ Verifies that adaptive k properly adjusts based on bucket sizes.
import pytest
import torch
import numpy as np
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset

View File

@@ -0,0 +1,183 @@
import torch
from typing import Union
class MockGammaBDataset:
    """Lightweight stand-in for GammaBDataset used in gradient-flow tests.

    Only implements ``compute_sigma_t_x`` with the same math as the real
    dataset; no file loading or other I/O is performed.
    """

    def __init__(self, *args, **kwargs):
        """Pick a compute device; accepts and ignores any constructor args."""
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """Interpolate between ``x`` and its CDC-transformed version at time ``t``.

        Args:
            eigenvectors: ``(B, k, d)`` per-sample eigenvector stack.
            eigenvalues: ``(B, k)`` per-sample eigenvalues.
            x: latent, either ``(B, d)`` or ``(B, C, H, W)``.
            t: interpolation time, python float or tensor scalar.

        Returns:
            Tensor with the same shape as the input ``x``.
        """
        input_shape = x.shape
        # Work on a flat (B, d) view when a 4D latent is passed in.
        if x.dim() == 4:
            x = x.reshape(x.shape[0], -1)
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)
        # Sanity-check CDC component shapes against the (flattened) latent.
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"
        # Short-circuit at t == 0, but only when t carries no gradient —
        # otherwise the graph through t would be silently dropped.
        t_is_zero = torch.allclose(t, torch.zeros_like(t), atol=1e-8)
        if t_is_zero and not t.requires_grad:
            return x.reshape(input_shape)
        # Projection onto the eigenbasis: V^T x.
        coeffs = torch.einsum('bkd,bd->bk', eigenvectors, x)
        # Scale by sqrt(lambda); clamp keeps sqrt away from zero for stability.
        scaled_coeffs = torch.sqrt(eigenvalues.clamp(min=1e-10)) * coeffs
        # Back to latent space: V (sqrt(lambda) V^T x).
        transformed = torch.einsum('bkd,bk->bd', eigenvectors, scaled_coeffs)
        # Linear interpolation between the clean and transformed latent.
        out = (1 - t) * x + t * transformed
        return out.reshape(input_shape)
class TestCDCAdvanced:
    """Gradient-flow tests for the CDC sigma_t(x) transformation."""

    def setup_method(self):
        """Pick the compute device once per test."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _build_inputs(self, batch_size=4, latent_dim=64, n_components=8):
        """Create a differentiable latent plus normalized mock eigencomponents.

        RNG consumption order (randn, randn, rand) matters for seeded tests.
        """
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
        eigenvectors = torch.randn(batch_size, n_components, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, n_components, device=self.device)
        # Normalize eigenvectors and bound eigenvalues away from degeneracy.
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
        return latent, eigenvectors, eigenvalues

    def test_gradient_flow_preservation(self):
        """
        Verify that gradient flow is preserved even for near-zero time steps
        with learnable time embeddings
        """
        torch.manual_seed(42)
        # Learnable time embedding with a small (near-zero) initial value.
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
        latent, eigenvectors, eigenvalues = self._build_inputs()
        noisy_latent = MockGammaBDataset().compute_sigma_t_x(
            eigenvectors, eigenvalues, latent, t
        )
        # Drive backprop through a scalar surrogate loss.
        loss = noisy_latent.sum()
        loss.backward()
        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()
        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"
        print(f"Time embedding gradient magnitude: {t_grad_magnitude}")
        print(f"Latent gradient magnitude: {latent_grad_magnitude}")

    def test_gradient_flow_with_different_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        for time_val in (0.0001, 0.001, 0.01, 0.1, 0.5, 1.0):
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
            latent, eigenvectors, eigenvalues = self._build_inputs()
            noisy_latent = MockGammaBDataset().compute_sigma_t_x(
                eigenvectors, eigenvalues, latent, t
            )
            loss = noisy_latent.sum()
            loss.backward()
            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()
            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"
            # Fresh tensors are built each pass; clearing grads is hygiene only.
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()
def pytest_configure(config):
    """Register the custom ``gradient_flow`` marker with pytest."""
    marker_line = (
        "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
    )
    config.addinivalue_line("markers", marker_line)

View File

@@ -0,0 +1,146 @@
"""
Test CDC-FM dimension handling and fallback mechanisms.
This module tests the behavior of the CDC Flow Matching implementation
when encountering latents with different dimensions.
"""
import torch
import logging
import tempfile
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestDimensionHandling:
    """Tests for CDC-FM behavior on latents with mismatched dimensions."""

    def setup_method(self):
        """Set up a logger and target device before each test."""
        self.logger = logging.getLogger(__name__)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _attach_log_capture(target_logger):
        """Attach a handler recording formatted messages; returns (handler, messages).

        Caller is responsible for removing the returned handler afterwards.
        """
        messages = []

        class _Capture(logging.Handler):
            def emit(self, record):
                messages.append(record.getMessage())

        handler = _Capture()
        target_logger.addHandler(handler)
        return handler, messages

    def test_mixed_dimension_fallback(self):
        """
        Verify that preprocessor falls back to standard noise for mixed-dimension batches
        """
        preprocessor = CDCPreprocessor(debug=True)
        # Two 3D latents whose widths disagree (3x32x64 vs 3x32x128).
        latents = [
            torch.randn(3, 32, 64),
            torch.randn(3, 32, 128),
        ]
        from library.cdc_fm import logger
        capture_handler, log_messages = self._attach_log_capture(logger)
        try:
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                for i, latent in enumerate(latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_mixed_image_{i}'}
                    )
                try:
                    cdc_path = preprocessor.compute_all(tmp_file.name)
                except ValueError as e:
                    # Raising on mismatched dimensions is an acceptable outcome.
                    assert "Dimension mismatch" in str(e)
                    return
                # A warning mentioning the mismatch must have been logged.
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension mismatch" in msg.lower()
                ]
                assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
                # Load results and verify fallback
                dataset = GammaBDataset(cdc_path)
        finally:
            logger.removeHandler(capture_handler)
        # Every input sample must still be represented in the output.
        assert dataset.num_samples == len(latents), "All samples should be processed"

    def test_adaptive_k_with_dimension_constraints(self):
        """
        Test adaptive k-neighbors behavior with dimension constraints
        """
        preprocessor = CDCPreprocessor(
            adaptive_k=True,
            min_bucket_size=5,
            debug=True
        )
        # Latents with similar but not identical widths.
        base_latent = torch.randn(3, 32, 64)
        similar_latents = [
            base_latent,
            torch.randn(3, 32, 65),
            torch.randn(3, 32, 66),
        ]
        from library.cdc_fm import logger
        capture_handler, log_messages = self._attach_log_capture(logger)
        try:
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                for i, latent in enumerate(similar_latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_adaptive_k_image_{i}'}
                    )
                cdc_path = preprocessor.compute_all(tmp_file.name)
                dataset = GammaBDataset(cdc_path)
                assert dataset.num_samples == len(similar_latents), "All samples should be processed"
                # Surface any dimension-related diagnostics for inspection.
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension" in msg.lower()
                ]
                print(f"Dimension-related warnings: {dimension_warnings}")
        finally:
            logger.removeHandler(capture_handler)
def pytest_configure(config):
    """Register the ``dimension_handling`` marker so pytest recognizes it."""
    marker_line = "dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
    config.addinivalue_line("markers", marker_line)

View File

@@ -0,0 +1,164 @@
"""
Tests using realistic high-dimensional data to catch scaling bugs.
This test uses realistic VAE-like latents to ensure eigenvalue normalization
works correctly on real-world data.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestRealisticDataScaling:
    """Test eigenvalue scaling with realistic high-dimensional data"""

    def test_high_dimensional_latents_not_saturated(self, tmp_path):
        """
        Verify that high-dimensional realistic latents don't saturate eigenvalues.
        This test simulates real FLUX training data:
        - High dimension (16×64×64 = 65536)
        - Varied content (different variance in different regions)
        - Realistic magnitude (VAE output scale)
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        # Build the shared channel patterns once, vectorized. The previous
        # per-element Python loops walked 16*64*64 = 65536 cells for every one
        # of the 20 samples; these grids produce the same patterns in a few
        # tensor ops.
        h = torch.arange(64, dtype=torch.float32).view(64, 1)
        w = torch.arange(64, dtype=torch.float32).view(1, 64)
        gradient_plane = (h + w) / 128.0                         # channels 0-3: diagonal gradient
        wave_plane = torch.sin(h / 10.0) * torch.cos(w / 10.0)   # channels 4-7: sinusoidal pattern
        uniform_levels = (torch.arange(8, 16, dtype=torch.float32) * 0.1).view(8, 1, 1)  # channels 8-15
        # Create 20 samples with realistic varied structure
        for i in range(20):
            latent = torch.empty(16, 64, 64, dtype=torch.float32)
            latent[:4] = gradient_plane
            latent[4:8] = wave_plane
            latent[8:] = uniform_levels
            # Add per-sample variation (different "subjects")
            latent = latent * (1.0 + i * 0.2)
            # Add realistic VAE-like noise/variation
            latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_realistic_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)
        # Verify eigenvalues are NOT all saturated at 1.0
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        all_eigvals = np.array(all_eigvals)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
        # Critical: eigenvalues should NOT all be 1.0
        at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
        total = len(non_zero_eigvals)
        percent_at_max = (at_max / total * 100) if total > 0 else 0
        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
        print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
        # FAIL if too many eigenvalues are saturated at 1.0
        assert percent_at_max < 80, (
            f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
            f"This indicates the normalization bug - raw eigenvalues are not being "
            f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
        )
        # Should have good diversity
        assert np.std(non_zero_eigvals) > 0.1, (
            f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
            f"Should see diverse eigenvalues, not all the same value."
        )
        # Mean should be in reasonable range (not all 1.0)
        mean_eigval = np.mean(non_zero_eigvals)
        assert 0.05 < mean_eigval < 0.9, (
            f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
            f"If mean ≈ 1.0, eigenvalues are saturated."
        )

    def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
        """
        Test that datasets with more variance produce more diverse eigenvalues.
        This ensures the normalization preserves relative information.
        """
        results = {}
        # Vectorized coordinate grids replace the former triple Python loop
        # (16*32*32 = 16384 iterations per sample).
        c = torch.arange(16, dtype=torch.float32).view(16, 1, 1)
        h = torch.arange(32, dtype=torch.float32).view(1, 32, 1)
        w = torch.arange(32, dtype=torch.float32).view(1, 1, 32)
        # Create two preprocessors with different data variance
        for variance_scale in [0.5, 2.0]:
            preprocessor = CDCPreprocessor(
                k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
            )
            for i in range(15):
                # Varied pattern: sin/cos grid scaled by the variance factor.
                latent = torch.sin(h / 5.0 + i) * torch.cos(w / 5.0 + c) * variance_scale
                metadata = {'image_key': f'test_image_{i}'}
                preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
            output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
            preprocessor.compute_all(save_path=output_path)
            with safe_open(str(output_path), framework="pt", device="cpu") as f:
                eigvals = []
                for i in range(15):
                    ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                    eigvals.extend(ev[ev > 1e-6])
            results[variance_scale] = {
                'mean': np.mean(eigvals),
                'std': np.std(eigvals),
                'range': (np.min(eigvals), np.max(eigvals))
            }
        print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
        print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
        # Both should have diversity (not saturated)
        for scale in [0.5, 2.0]:
            assert results[scale]['std'] > 0.1, (
                f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
            )
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -6,8 +6,6 @@ Ensures that gradients propagate correctly through both fast and slow paths.
import pytest
import torch
import tempfile
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation

View File

@@ -4,7 +4,6 @@ Test comparing interpolation vs pad/truncate for CDC preprocessing.
This test quantifies the difference between the two approaches.
"""
import numpy as np
import pytest
import torch
import torch.nn.functional as F
@@ -89,16 +88,16 @@ class TestInterpolationComparison:
print("\n" + "=" * 60)
print("Reconstruction Error Comparison")
print("=" * 60)
print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
print(f" Interpolation error: {interp_error_small:.6f}")
print(f" Pad/truncate error: {pad_error_small:.6f}")
if pad_error_small > 0:
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
else:
print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
print(f" BUT the intermediate representation is corrupted with zeros!")
print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
print(" BUT the intermediate representation is corrupted with zeros!")
print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
print(f" Interpolation error: {interp_error_large:.6f}")
print(f" Pad/truncate error: {truncate_error_large:.6f}")
if truncate_error_large > 0:
@@ -151,7 +150,7 @@ class TestInterpolationComparison:
print("\n" + "=" * 60)
print("Spatial Structure Preservation")
print("=" * 60)
print(f"\nGradient smoothness (lower is smoother):")
print("\nGradient smoothness (lower is smoother):")
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")

View File

@@ -0,0 +1,268 @@
"""
Performance benchmarking for CDC Flow Matching implementation.
This module tests the computational overhead and noise injection properties
of the CDC-FM preprocessing pipeline.
"""
import time
import tempfile
import torch
import numpy as np
import pytest
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPerformance:
    """
    Performance and Noise Injection Verification Tests for CDC Flow Matching
    These tests validate the computational performance and noise injection properties
    of the CDC-FM preprocessing pipeline across different latent sizes.
    Key Verification Points:
    1. Computational efficiency for various latent dimensions
    2. Noise injection statistical properties
    3. Eigenvector and eigenvalue characteristics
    """

    @pytest.fixture(params=[
        (3, 32, 32),  # Small latent: typical for compact representations
        (3, 64, 64),  # Medium latent: standard feature maps
        (3, 128, 128)  # Large latent: high-resolution feature spaces
    ])
    def latent_sizes(self, request):
        """
        Parametrized fixture generating test cases for different latent sizes.
        Rationale:
        - Tests robustness across various computational scales
        - Ensures consistent behavior from compact to large representations
        - Identifies potential dimensionality-related performance bottlenecks
        """
        return request.param

    def test_computational_overhead(self, latent_sizes):
        """
        Measure computational overhead of CDC preprocessing across latent sizes.
        Performance Verification Objectives:
        1. Verify preprocessing time scales predictably with input dimensions
        2. Ensure adaptive k-neighbors works efficiently
        3. Validate computational overhead remains within acceptable bounds
        Performance Metrics:
        - Total preprocessing time
        - Per-sample processing time
        - Computational complexity indicators
        Args:
            latent_sizes (tuple): Latent dimensions (C, H, W) to benchmark
        """
        # Tuned preprocessing configuration
        preprocessor = CDCPreprocessor(
            k_neighbors=256,  # Comprehensive neighborhood exploration
            d_cdc=8,  # Geometric embedding dimensionality
            debug=True,  # Enable detailed performance logging
            adaptive_k=True  # Dynamic neighborhood size adjustment
        )
        # Set a fixed random seed for reproducibility
        torch.manual_seed(42)  # Consistent random generation
        # Generate representative latent batch
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)
        # Precision timing of preprocessing (covers add_latent + compute_all)
        start_time = time.perf_counter()
        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with traceable metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'perf_test_image_{i}'}
                )
            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)
        # Calculate precise preprocessing metrics
        end_time = time.perf_counter()
        preprocessing_time = end_time - start_time
        per_sample_time = preprocessing_time / batch_size
        # Performance reporting and assertions
        input_volume = np.prod(latent_sizes)
        time_complexity_indicator = preprocessing_time / input_volume
        print(f"\nPerformance Breakdown:")
        print(f" Latent Size: {latent_sizes}")
        print(f" Total Samples: {batch_size}")
        print(f" Input Volume: {input_volume}")
        print(f" Total Time: {preprocessing_time:.4f} seconds")
        print(f" Per Sample Time: {per_sample_time:.6f} seconds")
        print(f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel")
        # Adaptive thresholds based on input dimensions
        max_total_time = 10.0  # Base threshold
        max_per_sample_time = 2.0  # Per-sample time threshold (more lenient)
        # Different time complexity thresholds for different latent sizes
        # NOTE(review): 3072 == 3*32*32, i.e. only the smallest fixture gets
        # the lenient 1e-2 threshold — confirm this boundary is intentional.
        max_time_complexity = (
            1e-2 if np.prod(latent_sizes) <= 3072 else  # Smaller latents
            1e-4  # Standard latents
        )
        # Performance assertions with informative error messages
        assert preprocessing_time < max_total_time, (
            f"Total preprocessing time exceeded threshold!\n"
            f" Latent Size: {latent_sizes}\n"
            f" Total Time: {preprocessing_time:.4f} seconds\n"
            f" Threshold: {max_total_time} seconds"
        )
        assert per_sample_time < max_per_sample_time, (
            f"Per-sample processing time exceeded threshold!\n"
            f" Latent Size: {latent_sizes}\n"
            f" Per Sample Time: {per_sample_time:.6f} seconds\n"
            f" Threshold: {max_per_sample_time} seconds"
        )
        # More adaptable time complexity check
        assert time_complexity_indicator < max_time_complexity, (
            f"Time complexity scaling exceeded expectations!\n"
            f" Latent Size: {latent_sizes}\n"
            f" Input Volume: {input_volume}\n"
            f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n"
            f" Threshold: {max_time_complexity} seconds/voxel"
        )

    def test_noise_distribution(self, latent_sizes):
        """
        Verify CDC noise injection quality and properties.
        Based on test plan objectives:
        1. CDC noise is actually being generated (not all Gaussian fallback)
        2. Eigenvalues are valid (non-negative, bounded)
        3. CDC components are finite and usable for noise generation
        Args:
            latent_sizes (tuple): Latent dimensions (C, H, W)
        """
        # Preprocessing configuration
        preprocessor = CDCPreprocessor(
            k_neighbors=16,  # Reduced to match batch size
            d_cdc=8,
            gamma=1.0,
            debug=True,
            adaptive_k=True
        )
        # Set a fixed random seed for reproducibility
        torch.manual_seed(42)
        # Generate batch of latents
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)
        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'noise_dist_image_{i}'}
                )
            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)
            # Analyze noise properties (inside the `with` so the temp file
            # still exists while GammaBDataset reads it)
            dataset = GammaBDataset(cdc_path)
            # Track samples that used CDC vs Gaussian fallback
            cdc_samples = 0
            gaussian_samples = 0
            eigenvalue_stats = {
                'min': float('inf'),
                'max': float('-inf'),
                'mean': 0.0,
                'sum': 0.0
            }
            # Verify each sample's CDC components
            for i in range(batch_size):
                image_key = f'noise_dist_image_{i}'
                # Get eigenvectors and eigenvalues
                eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key])
                # Skip zero eigenvectors (fallback case): an all-zero
                # eigenvector tensor marks a sample that used Gaussian noise
                if torch.all(eigvecs[0] == 0):
                    gaussian_samples += 1
                    continue
                # Get the top d_cdc eigenvectors and eigenvalues
                top_eigvecs = eigvecs[0]  # (d_cdc, d)
                top_eigvals = eigvals[0]  # (d_cdc,)
                # Basic validity checks
                assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}"
                assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}"
                # Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM)
                assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}"
                assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}"
                # Update statistics
                eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item())
                eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item())
                eigenvalue_stats['sum'] += top_eigvals.sum().item()
                cdc_samples += 1
            # Compute mean eigenvalue across all CDC samples
            if cdc_samples > 0:
                eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * 8)  # 8 = d_cdc
            # Print final statistics
            print(f"\nNoise Distribution Results for latent size {latent_sizes}:")
            print(f" CDC samples: {cdc_samples}/{batch_size}")
            print(f" Gaussian fallback: {gaussian_samples}/{batch_size}")
            print(f" Eigenvalue min: {eigenvalue_stats['min']:.4f}")
            print(f" Eigenvalue max: {eigenvalue_stats['max']:.4f}")
            print(f" Eigenvalue mean: {eigenvalue_stats['mean']:.4f}")
            # Assertions based on plan objectives
            # 1. CDC noise should be generated for most samples
            assert cdc_samples > 0, "No samples used CDC noise injection"
            assert gaussian_samples < batch_size // 2, (
                f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}"
            )
            # 2. Eigenvalues should be valid (non-negative and bounded)
            assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative"
            assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0"
            # 3. Mean eigenvalue should be reasonable (not degenerate)
            assert eigenvalue_stats['mean'] > 0.05, (
                f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), "
                "suggests degenerate CDC components"
            )
def pytest_configure(config):
    """Register the performance and noise-distribution markers with pytest."""
    for marker_line in (
        "performance: mark test to verify CDC-FM computational performance",
        "noise_distribution: mark test to verify noise injection properties",
    ):
        config.addinivalue_line("markers", marker_line)

View File

@@ -0,0 +1,237 @@
"""
Tests to validate the CDC rescaling recommendations from paper review.
These tests check:
1. Gamma parameter interaction with rescaling
2. Spatial adaptivity of eigenvalue scaling
3. Verification of fixed vs adaptive rescaling behavior
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestGammaRescalingInteraction:
    """Test that gamma parameter works correctly with eigenvalue rescaling"""

    def test_gamma_scales_eigenvalues_correctly(self, tmp_path):
        """Verify gamma multiplier is applied correctly after rescaling"""
        # Create two preprocessors with different gamma values
        gamma_values = [0.5, 1.0, 2.0]
        eigenvalue_results = {}
        for gamma in gamma_values:
            preprocessor = CDCPreprocessor(
                k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu"
            )
            # Add identical deterministic data for all runs
            # (element-wise fill keeps values bit-identical across gamma runs)
            for i in range(10):
                latent = torch.zeros(16, 4, 4, dtype=torch.float32)
                for c in range(16):
                    for h in range(4):
                        for w in range(4):
                            latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.1
                metadata = {'image_key': f'test_image_{i}'}
                preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
            output_path = tmp_path / f"test_gamma_{gamma}.safetensors"
            preprocessor.compute_all(save_path=output_path)
            # Extract eigenvalues of the first sample for comparison
            with safe_open(str(output_path), framework="pt", device="cpu") as f:
                eigvals = f.get_tensor("eigenvalues/test_image_0").numpy()
                eigenvalue_results[gamma] = eigvals
        # With clamping to [1e-3, gamma*1.0], verify gamma changes the upper bound
        # Gamma 0.5: max eigenvalue should be ~0.5
        # Gamma 1.0: max eigenvalue should be ~1.0
        # Gamma 2.0: max eigenvalue should be ~2.0
        max_0p5 = np.max(eigenvalue_results[0.5])
        max_1p0 = np.max(eigenvalue_results[1.0])
        max_2p0 = np.max(eigenvalue_results[2.0])
        assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}"
        assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}"
        assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}"
        # All should have min of 1e-3 (clamp lower bound)
        assert np.min(eigenvalue_results[0.5][eigenvalue_results[0.5] > 0]) >= 1e-3
        assert np.min(eigenvalue_results[1.0][eigenvalue_results[1.0] > 0]) >= 1e-3
        assert np.min(eigenvalue_results[2.0][eigenvalue_results[2.0] > 0]) >= 1e-3
        print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}")
        print(f"✓ Gamma 1.0 max: {max_1p0:.4f}")
        print(f"✓ Gamma 2.0 max: {max_2p0:.4f}")

    def test_large_gamma_maintains_reasonable_scale(self, tmp_path):
        """Verify that large gamma values don't cause eigenvalue explosion"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu"
        )
        # Deterministic, mildly varied samples
        for i in range(10):
            latent = torch.zeros(16, 4, 4, dtype=torch.float32)
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] = (c + h + w) / 20.0 + i * 0.15
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_large_gamma.safetensors"
        preprocessor.compute_all(save_path=output_path)
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)
        max_eigval = np.max(all_eigvals)
        # Mean over non-(near-)zero eigenvalues only
        mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6])
        # With gamma=10.0 and target_scale=0.1, eigenvalues should be ~1.0
        # But they should still be reasonable (not exploding)
        assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma"
        assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma"
        print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}")
class TestSpatialAdaptivityOfRescaling:
    """Test spatial variation in eigenvalue scaling"""

    def test_eigenvalues_vary_spatially(self, tmp_path):
        """Verify eigenvalues differ across spatially separated clusters"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        # Create two distinct clusters in latent space
        # Cluster 1: Tight cluster (low variance) - deterministic spread
        for i in range(10):
            latent = torch.zeros(16, 4, 4)
            # Small variation around 0
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] = (c + h + w) / 100.0 + i * 0.01
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        # Cluster 2: Loose cluster (high variance) - deterministic spread
        for i in range(10, 20):
            latent = torch.ones(16, 4, 4) * 5.0
            # Large variation around 5.0
            for c in range(16):
                for h in range(4):
                    for w in range(4):
                        latent[c, h, w] += (c + h + w) / 10.0 + (i - 10) * 0.2
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_spatial_variation.safetensors"
        preprocessor.compute_all(save_path=output_path)
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            # Get eigenvalues from both clusters
            cluster1_eigvals = []
            cluster2_eigvals = []
            for i in range(10):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                cluster1_eigvals.append(np.max(eigvals))
            for i in range(10, 20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                cluster2_eigvals.append(np.max(eigvals))
        cluster1_mean = np.mean(cluster1_eigvals)
        cluster2_mean = np.mean(cluster2_eigvals)
        print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}")
        print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}")
        # With fixed target_scale rescaling, eigenvalues should be similar
        # despite different local geometry
        # This demonstrates the limitation of fixed rescaling
        ratio = cluster2_mean / (cluster1_mean + 1e-10)  # epsilon avoids div-by-zero
        print(f"✓ Ratio (loose/tight): {ratio:.2f}")
        # Both should be rescaled to similar magnitude (~0.1 due to target_scale)
        assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range"
        assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range"
class TestFixedVsAdaptiveRescaling:
    """Compare current fixed rescaling vs paper's adaptive approach"""

    def test_current_rescaling_is_uniform(self, tmp_path):
        """Demonstrate that current rescaling produces uniform eigenvalue scales"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )
        # Create samples with varying local density - deterministic
        for i in range(20):
            latent = torch.zeros(16, 4, 4)
            # Some samples clustered, some isolated
            if i < 10:
                # Dense cluster around origin
                for c in range(16):
                    for h in range(4):
                        for w in range(4):
                            latent[c, h, w] = (c + h + w) / 40.0 + i * 0.05
            else:
                # Isolated points - larger offset
                for c in range(16):
                    for h in range(4):
                        for w in range(4):
                            latent[c, h, w] = (c + h + w) / 40.0 + i * 2.0
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
        output_path = tmp_path / "test_uniform_rescaling.safetensors"
        preprocessor.compute_all(save_path=output_path)
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            max_eigenvalues = []
            for i in range(20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                vals = eigvals[eigvals > 1e-6]
                if vals.size:  # at least one valid eigen-value
                    max_eigenvalues.append(vals.max())
        if not max_eigenvalues:  # safeguard against empty list
            pytest.skip("no valid eigen-values found")
        max_eigenvalues = np.array(max_eigenvalues)
        # Check coefficient of variation (std / mean)
        cv = max_eigenvalues.std() / max_eigenvalues.mean()
        print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f}, {np.max(max_eigenvalues):.4f}]")
        print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}")
        print(f"✓ Coefficient of variation: {cv:.4f}")
        # With clamping, eigenvalues should have relatively low variation
        assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping"
        # Mean should be reasonable (clamped to [1e-3, gamma*1.0] = [1e-3, 1.0])
        assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range"
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -5,10 +5,8 @@ These tests focus on CDC-FM specific functionality without importing
the full training infrastructure that has problematic dependencies.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
import torch
from safetensors.torch import save_file

View File

@@ -7,7 +7,6 @@ Ensures that duplicate warnings for the same sample are not logged repeatedly.
import pytest
import torch
import logging
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples