From 8089cb6925eeb6828fc49494dc59c3cf60a03276 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 11 Oct 2025 17:17:09 -0400 Subject: [PATCH] Improve dimension mismatch warning for CDC Flow Matching - Add explicit warning and tracking for multiple unique latent shapes - Simplify test imports by removing unused modules - Minor formatting improvements in print statements - Ensure log messages provide clear context about dimension mismatches --- library/cdc_fm.py | 11 + tests/library/test_cdc_adaptive_k.py | 2 - tests/library/test_cdc_advanced.py | 183 ++++++++++++ tests/library/test_cdc_dimension_handling.py | 146 ++++++++++ .../library/test_cdc_eigenvalue_real_data.py | 164 +++++++++++ tests/library/test_cdc_gradient_flow.py | 2 - .../test_cdc_interpolation_comparison.py | 11 +- tests/library/test_cdc_performance.py | 268 ++++++++++++++++++ .../test_cdc_rescaling_recommendations.py | 237 ++++++++++++++++ tests/library/test_cdc_standalone.py | 2 - tests/library/test_cdc_warning_throttling.py | 1 - 11 files changed, 1014 insertions(+), 13 deletions(-) create mode 100644 tests/library/test_cdc_advanced.py create mode 100644 tests/library/test_cdc_dimension_handling.py create mode 100644 tests/library/test_cdc_eigenvalue_real_data.py create mode 100644 tests/library/test_cdc_performance.py create mode 100644 tests/library/test_cdc_rescaling_recommendations.py diff --git a/library/cdc_fm.py b/library/cdc_fm.py index f4678f46..10b00864 100644 --- a/library/cdc_fm.py +++ b/library/cdc_fm.py @@ -354,9 +354,11 @@ class LatentBatcher: Dict mapping exact_shape -> list of samples with that shape """ batches = {} + shapes = set() for sample in self.samples: shape_key = sample.shape + shapes.add(shape_key) # Group by exact shape only - no aspect ratio grouping or resizing if shape_key not in batches: @@ -364,6 +366,15 @@ class LatentBatcher: batches[shape_key].append(sample) + # If more than one unique shape, log a warning + if len(shapes) > 1: + logger.warning( + "Dimension 
mismatch: %d unique shapes detected. " + "Shapes: %s. Using Gaussian fallback for these samples.", + len(shapes), + shapes + ) + return batches def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str: diff --git a/tests/library/test_cdc_adaptive_k.py b/tests/library/test_cdc_adaptive_k.py index aaa050f0..f5de5fac 100644 --- a/tests/library/test_cdc_adaptive_k.py +++ b/tests/library/test_cdc_adaptive_k.py @@ -6,8 +6,6 @@ Verifies that adaptive k properly adjusts based on bucket sizes. import pytest import torch -import numpy as np -from pathlib import Path from library.cdc_fm import CDCPreprocessor, GammaBDataset diff --git a/tests/library/test_cdc_advanced.py b/tests/library/test_cdc_advanced.py new file mode 100644 index 00000000..e2a43ea4 --- /dev/null +++ b/tests/library/test_cdc_advanced.py @@ -0,0 +1,183 @@ +import torch +from typing import Union + + +class MockGammaBDataset: + """ + Mock implementation of GammaBDataset for testing gradient flow + """ + def __init__(self, *args, **kwargs): + """ + Simple initialization that doesn't require file loading + """ + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def compute_sigma_t_x( + self, + eigenvectors: torch.Tensor, + eigenvalues: torch.Tensor, + x: torch.Tensor, + t: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Simplified implementation of compute_sigma_t_x for testing + """ + # Store original shape to restore later + orig_shape = x.shape + + # Flatten x if it's 4D + if x.dim() == 4: + B, C, H, W = x.shape + x = x.reshape(B, -1) # (B, C*H*W) + + if not isinstance(t, torch.Tensor): + t = torch.tensor(t, device=x.device, dtype=x.dtype) + + # Validate dimensions + assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch" + assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch" + + # Early return for t=0 with gradient preservation + if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad: + return x.reshape(orig_shape) + + # 
Compute Σ_t @ x + # V^T x + Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x) + + # sqrt(λ) * V^T x + sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10)) + sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x + + # V @ (sqrt(λ) * V^T x) + gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x) + + # Interpolate between original and noisy latent + result = (1 - t) * x + t * gamma_sqrt_x + + # Restore original shape + result = result.reshape(orig_shape) + + return result + +class TestCDCAdvanced: + def setup_method(self): + """Prepare consistent test environment""" + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def test_gradient_flow_preservation(self): + """ + Verify that gradient flow is preserved even for near-zero time steps + with learnable time embeddings + """ + # Set random seed for reproducibility + torch.manual_seed(42) + + # Create a learnable time embedding with small initial value + t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32) + + # Generate mock latent and CDC components + batch_size, latent_dim = 4, 64 + latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True) + + # Create mock eigenvectors and eigenvalues + eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device) + eigenvalues = torch.rand(batch_size, 8, device=self.device) + + # Ensure eigenvectors and eigenvalues are meaningful + eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True) + eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0) + + # Use the mock dataset + mock_dataset = MockGammaBDataset() + + # Compute noisy latent with gradient tracking + noisy_latent = mock_dataset.compute_sigma_t_x( + eigenvectors, + eigenvalues, + latent, + t + ) + + # Compute a dummy loss to check gradient flow + loss = noisy_latent.sum() + + # Compute gradients + loss.backward() + + # Assertions to verify gradient flow + assert t.grad is not None, "Time embedding 
gradient should be computed" + assert latent.grad is not None, "Input latent gradient should be computed" + + # Check gradient magnitudes are non-zero + t_grad_magnitude = torch.abs(t.grad).sum() + latent_grad_magnitude = torch.abs(latent.grad).sum() + + assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}" + assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}" + + # Optional: Print gradient details for debugging + print(f"Time embedding gradient magnitude: {t_grad_magnitude}") + print(f"Latent gradient magnitude: {latent_grad_magnitude}") + + def test_gradient_flow_with_different_time_steps(self): + """ + Verify gradient flow across different time step values + """ + # Test time steps + time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0] + + for time_val in time_steps: + # Create a learnable time embedding + t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32) + + # Generate mock latent and CDC components + batch_size, latent_dim = 4, 64 + latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True) + + # Create mock eigenvectors and eigenvalues + eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device) + eigenvalues = torch.rand(batch_size, 8, device=self.device) + + # Ensure eigenvectors and eigenvalues are meaningful + eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True) + eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0) + + # Use the mock dataset + mock_dataset = MockGammaBDataset() + + # Compute noisy latent with gradient tracking + noisy_latent = mock_dataset.compute_sigma_t_x( + eigenvectors, + eigenvalues, + latent, + t + ) + + # Compute a dummy loss to check gradient flow + loss = noisy_latent.sum() + + # Compute gradients + loss.backward() + + # Assertions to verify gradient flow + t_grad_magnitude = torch.abs(t.grad).sum() + latent_grad_magnitude = torch.abs(latent.grad).sum() + + assert 
t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}" + assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}" + + # Reset gradients for next iteration + if t.grad is not None: + t.grad.zero_() + if latent.grad is not None: + latent.grad.zero_() + +def pytest_configure(config): + """ + Add custom markers for CDC-FM tests + """ + config.addinivalue_line( + "markers", + "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching" + ) \ No newline at end of file diff --git a/tests/library/test_cdc_dimension_handling.py b/tests/library/test_cdc_dimension_handling.py new file mode 100644 index 00000000..147a1d7e --- /dev/null +++ b/tests/library/test_cdc_dimension_handling.py @@ -0,0 +1,146 @@ +""" +Test CDC-FM dimension handling and fallback mechanisms. + +This module tests the behavior of the CDC Flow Matching implementation +when encountering latents with different dimensions. +""" + +import torch +import logging +import tempfile + +from library.cdc_fm import CDCPreprocessor, GammaBDataset + +class TestDimensionHandling: + def setup_method(self): + """Prepare consistent test environment""" + self.logger = logging.getLogger(__name__) + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def test_mixed_dimension_fallback(self): + """ + Verify that preprocessor falls back to standard noise for mixed-dimension batches + """ + # Prepare preprocessor with debug mode + preprocessor = CDCPreprocessor(debug=True) + + # Different-sized latents (3D: channels, height, width) + latents = [ + torch.randn(3, 32, 64), # First latent: 3x32x64 + torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension) + ] + + # Use a mock handler to capture log messages + from library.cdc_fm import logger + + log_messages = [] + class LogCapture(logging.Handler): + def emit(self, record): + log_messages.append(record.getMessage()) + + # Temporarily add a capture handler + capture_handler = 
LogCapture() + logger.addHandler(capture_handler) + + try: + # Try adding mixed-dimension latents + with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: + for i, latent in enumerate(latents): + preprocessor.add_latent( + latent, + global_idx=i, + metadata={'image_key': f'test_mixed_image_{i}'} + ) + + try: + cdc_path = preprocessor.compute_all(tmp_file.name) + except ValueError as e: + # If implementation raises ValueError, that's acceptable + assert "Dimension mismatch" in str(e) + return + + # Check for dimension-related log messages + dimension_warnings = [ + msg for msg in log_messages + if "dimension mismatch" in msg.lower() + ] + assert len(dimension_warnings) > 0, "No dimension-related warnings were logged" + + # Load results and verify fallback + dataset = GammaBDataset(cdc_path) + + finally: + # Remove the capture handler + logger.removeHandler(capture_handler) + + # Check metadata about samples with/without CDC + assert dataset.num_samples == len(latents), "All samples should be processed" + + def test_adaptive_k_with_dimension_constraints(self): + """ + Test adaptive k-neighbors behavior with dimension constraints + """ + # Prepare preprocessor with adaptive k and small bucket size + preprocessor = CDCPreprocessor( + adaptive_k=True, + min_bucket_size=5, + debug=True + ) + + # Generate latents with similar but not identical dimensions + base_latent = torch.randn(3, 32, 64) + similar_latents = [ + base_latent, + torch.randn(3, 32, 65), # Slightly different dimension + torch.randn(3, 32, 66) # Another slightly different dimension + ] + + # Use a mock handler to capture log messages + from library.cdc_fm import logger + + log_messages = [] + class LogCapture(logging.Handler): + def emit(self, record): + log_messages.append(record.getMessage()) + + # Temporarily add a capture handler + capture_handler = LogCapture() + logger.addHandler(capture_handler) + + try: + with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: + # Add 
similar latents + for i, latent in enumerate(similar_latents): + preprocessor.add_latent( + latent, + global_idx=i, + metadata={'image_key': f'test_adaptive_k_image_{i}'} + ) + + cdc_path = preprocessor.compute_all(tmp_file.name) + + # Load results + dataset = GammaBDataset(cdc_path) + + # Verify samples processed + assert dataset.num_samples == len(similar_latents), "All samples should be processed" + + # Optional: Check warnings about dimension differences + dimension_warnings = [ + msg for msg in log_messages + if "dimension" in msg.lower() + ] + print(f"Dimension-related warnings: {dimension_warnings}") + + finally: + # Remove the capture handler + logger.removeHandler(capture_handler) + +def pytest_configure(config): + """ + Configure custom markers for dimension handling tests + """ + config.addinivalue_line( + "markers", + "dimension_handling: mark test for CDC-FM dimension mismatch scenarios" + ) \ No newline at end of file diff --git a/tests/library/test_cdc_eigenvalue_real_data.py b/tests/library/test_cdc_eigenvalue_real_data.py new file mode 100644 index 00000000..3202b37c --- /dev/null +++ b/tests/library/test_cdc_eigenvalue_real_data.py @@ -0,0 +1,164 @@ +""" +Tests using realistic high-dimensional data to catch scaling bugs. + +This test uses realistic VAE-like latents to ensure eigenvalue normalization +works correctly on real-world data. +""" + +import numpy as np +import pytest +import torch +from safetensors import safe_open + +from library.cdc_fm import CDCPreprocessor + + +class TestRealisticDataScaling: + """Test eigenvalue scaling with realistic high-dimensional data""" + + def test_high_dimensional_latents_not_saturated(self, tmp_path): + """ + Verify that high-dimensional realistic latents don't saturate eigenvalues. 
+ + This test simulates real FLUX training data: + - High dimension (16×64×64 = 65536) + - Varied content (different variance in different regions) + - Realistic magnitude (VAE output scale) + """ + preprocessor = CDCPreprocessor( + k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" + ) + + # Create 20 samples with realistic varied structure + for i in range(20): + # High-dimensional latent like FLUX + latent = torch.zeros(16, 64, 64, dtype=torch.float32) + + # Create varied structure across the latent + # Different channels have different patterns (realistic for VAE) + for c in range(16): + # Some channels have gradients + if c < 4: + for h in range(64): + for w in range(64): + latent[c, h, w] = (h + w) / 128.0 + # Some channels have patterns + elif c < 8: + for h in range(64): + for w in range(64): + latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0) + # Some channels are more uniform + else: + latent[c, :, :] = c * 0.1 + + # Add per-sample variation (different "subjects") + latent = latent * (1.0 + i * 0.2) + + # Add realistic VAE-like noise/variation + latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3) + + metadata = {'image_key': f'test_image_{i}'} + + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + output_path = tmp_path / "test_realistic_gamma_b.safetensors" + result_path = preprocessor.compute_all(save_path=output_path) + + # Verify eigenvalues are NOT all saturated at 1.0 + with safe_open(str(result_path), framework="pt", device="cpu") as f: + all_eigvals = [] + for i in range(20): + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() + all_eigvals.extend(eigvals) + + all_eigvals = np.array(all_eigvals) + non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] + + # Critical: eigenvalues should NOT all be 1.0 + at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01) + total = len(non_zero_eigvals) + percent_at_max = (at_max / total * 100) if total > 0 
else 0 + + print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]") + print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}") + print(f"✓ Std: {np.std(non_zero_eigvals):.4f}") + print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)") + + # FAIL if too many eigenvalues are saturated at 1.0 + assert percent_at_max < 80, ( + f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! " + f"This indicates the normalization bug - raw eigenvalues are not being " + f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]" + ) + + # Should have good diversity + assert np.std(non_zero_eigvals) > 0.1, ( + f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. " + f"Should see diverse eigenvalues, not all the same value." + ) + + # Mean should be in reasonable range (not all 1.0) + mean_eigval = np.mean(non_zero_eigvals) + assert 0.05 < mean_eigval < 0.9, ( + f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. " + f"If mean ≈ 1.0, eigenvalues are saturated." + ) + + def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path): + """ + Test that datasets with more variance produce more diverse eigenvalues. + + This ensures the normalization preserves relative information. 
+ """ + # Create two preprocessors with different data variance + results = {} + + for variance_scale in [0.5, 2.0]: + preprocessor = CDCPreprocessor( + k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" + ) + + for i in range(15): + latent = torch.zeros(16, 32, 32, dtype=torch.float32) + + # Create varied patterns + for c in range(16): + for h in range(32): + for w in range(32): + latent[c, h, w] = ( + np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale + ) + + metadata = {'image_key': f'test_image_{i}'} + + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + output_path = tmp_path / f"test_variance_{variance_scale}.safetensors" + preprocessor.compute_all(save_path=output_path) + + with safe_open(str(output_path), framework="pt", device="cpu") as f: + eigvals = [] + for i in range(15): + ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() + eigvals.extend(ev[ev > 1e-6]) + + results[variance_scale] = { + 'mean': np.mean(eigvals), + 'std': np.std(eigvals), + 'range': (np.min(eigvals), np.max(eigvals)) + } + + print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}") + print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}") + + # Both should have diversity (not saturated) + for scale in [0.5, 2.0]: + assert results[scale]['std'] > 0.1, ( + f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_gradient_flow.py b/tests/library/test_cdc_gradient_flow.py index b0fd4cfa..a1fb515f 100644 --- a/tests/library/test_cdc_gradient_flow.py +++ b/tests/library/test_cdc_gradient_flow.py @@ -6,8 +6,6 @@ Ensures that gradients propagate correctly through both fast and slow paths. 
import pytest import torch -import tempfile -from pathlib import Path from library.cdc_fm import CDCPreprocessor, GammaBDataset from library.flux_train_utils import apply_cdc_noise_transformation diff --git a/tests/library/test_cdc_interpolation_comparison.py b/tests/library/test_cdc_interpolation_comparison.py index 9ad71eaf..46b2d8b2 100644 --- a/tests/library/test_cdc_interpolation_comparison.py +++ b/tests/library/test_cdc_interpolation_comparison.py @@ -4,7 +4,6 @@ Test comparing interpolation vs pad/truncate for CDC preprocessing. This test quantifies the difference between the two approaches. """ -import numpy as np import pytest import torch import torch.nn.functional as F @@ -89,16 +88,16 @@ class TestInterpolationComparison: print("\n" + "=" * 60) print("Reconstruction Error Comparison") print("=" * 60) - print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):") + print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):") print(f" Interpolation error: {interp_error_small:.6f}") print(f" Pad/truncate error: {pad_error_small:.6f}") if pad_error_small > 0: print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%") else: - print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)") - print(f" BUT the intermediate representation is corrupted with zeros!") + print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)") + print(" BUT the intermediate representation is corrupted with zeros!") - print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):") + print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):") print(f" Interpolation error: {interp_error_large:.6f}") print(f" Pad/truncate error: {truncate_error_large:.6f}") if truncate_error_large > 0: @@ -151,7 +150,7 @@ class TestInterpolationComparison: print("\n" + "=" * 60) print("Spatial Structure Preservation") print("=" * 60) - print(f"\nGradient smoothness (lower is smoother):") + print("\nGradient smoothness (lower is smoother):") print(f" Interpolation - 
X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}") print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}") diff --git a/tests/library/test_cdc_performance.py b/tests/library/test_cdc_performance.py new file mode 100644 index 00000000..8f63e6fe --- /dev/null +++ b/tests/library/test_cdc_performance.py @@ -0,0 +1,268 @@ +""" +Performance benchmarking for CDC Flow Matching implementation. + +This module tests the computational overhead and noise injection properties +of the CDC-FM preprocessing pipeline. +""" + +import time +import tempfile +import torch +import numpy as np +import pytest + +from library.cdc_fm import CDCPreprocessor, GammaBDataset + +class TestCDCPerformance: + """ + Performance and Noise Injection Verification Tests for CDC Flow Matching + + These tests validate the computational performance and noise injection properties + of the CDC-FM preprocessing pipeline across different latent sizes. + + Key Verification Points: + 1. Computational efficiency for various latent dimensions + 2. Noise injection statistical properties + 3. Eigenvector and eigenvalue characteristics + """ + + @pytest.fixture(params=[ + (3, 32, 32), # Small latent: typical for compact representations + (3, 64, 64), # Medium latent: standard feature maps + (3, 128, 128) # Large latent: high-resolution feature spaces + ]) + def latent_sizes(self, request): + """ + Parametrized fixture generating test cases for different latent sizes. + + Rationale: + - Tests robustness across various computational scales + - Ensures consistent behavior from compact to large representations + - Identifies potential dimensionality-related performance bottlenecks + """ + return request.param + + def test_computational_overhead(self, latent_sizes): + """ + Measure computational overhead of CDC preprocessing across latent sizes. + + Performance Verification Objectives: + 1. Verify preprocessing time scales predictably with input dimensions + 2. 
Ensure adaptive k-neighbors works efficiently + 3. Validate computational overhead remains within acceptable bounds + + Performance Metrics: + - Total preprocessing time + - Per-sample processing time + - Computational complexity indicators + + Args: + latent_sizes (tuple): Latent dimensions (C, H, W) to benchmark + """ + # Tuned preprocessing configuration + preprocessor = CDCPreprocessor( + k_neighbors=256, # Comprehensive neighborhood exploration + d_cdc=8, # Geometric embedding dimensionality + debug=True, # Enable detailed performance logging + adaptive_k=True # Dynamic neighborhood size adjustment + ) + + # Set a fixed random seed for reproducibility + torch.manual_seed(42) # Consistent random generation + + # Generate representative latent batch + batch_size = 32 + latents = torch.randn(batch_size, *latent_sizes) + + # Precision timing of preprocessing + start_time = time.perf_counter() + + with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: + # Add latents with traceable metadata + for i, latent in enumerate(latents): + preprocessor.add_latent( + latent, + global_idx=i, + metadata={'image_key': f'perf_test_image_{i}'} + ) + + # Compute CDC results + cdc_path = preprocessor.compute_all(tmp_file.name) + + # Calculate precise preprocessing metrics + end_time = time.perf_counter() + preprocessing_time = end_time - start_time + per_sample_time = preprocessing_time / batch_size + + # Performance reporting and assertions + input_volume = np.prod(latent_sizes) + time_complexity_indicator = preprocessing_time / input_volume + + print(f"\nPerformance Breakdown:") + print(f" Latent Size: {latent_sizes}") + print(f" Total Samples: {batch_size}") + print(f" Input Volume: {input_volume}") + print(f" Total Time: {preprocessing_time:.4f} seconds") + print(f" Per Sample Time: {per_sample_time:.6f} seconds") + print(f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel") + + # Adaptive thresholds based on input dimensions + max_total_time = 
10.0 # Base threshold + max_per_sample_time = 2.0 # Per-sample time threshold (more lenient) + + # Different time complexity thresholds for different latent sizes + max_time_complexity = ( + 1e-2 if np.prod(latent_sizes) <= 3072 else # Smaller latents + 1e-4 # Standard latents + ) + + # Performance assertions with informative error messages + assert preprocessing_time < max_total_time, ( + f"Total preprocessing time exceeded threshold!\n" + f" Latent Size: {latent_sizes}\n" + f" Total Time: {preprocessing_time:.4f} seconds\n" + f" Threshold: {max_total_time} seconds" + ) + + assert per_sample_time < max_per_sample_time, ( + f"Per-sample processing time exceeded threshold!\n" + f" Latent Size: {latent_sizes}\n" + f" Per Sample Time: {per_sample_time:.6f} seconds\n" + f" Threshold: {max_per_sample_time} seconds" + ) + + # More adaptable time complexity check + assert time_complexity_indicator < max_time_complexity, ( + f"Time complexity scaling exceeded expectations!\n" + f" Latent Size: {latent_sizes}\n" + f" Input Volume: {input_volume}\n" + f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n" + f" Threshold: {max_time_complexity} seconds/voxel" + ) + + def test_noise_distribution(self, latent_sizes): + """ + Verify CDC noise injection quality and properties. + + Based on test plan objectives: + 1. CDC noise is actually being generated (not all Gaussian fallback) + 2. Eigenvalues are valid (non-negative, bounded) + 3. 
CDC components are finite and usable for noise generation + + Args: + latent_sizes (tuple): Latent dimensions (C, H, W) + """ + # Preprocessing configuration + preprocessor = CDCPreprocessor( + k_neighbors=16, # Reduced to match batch size + d_cdc=8, + gamma=1.0, + debug=True, + adaptive_k=True + ) + + # Set a fixed random seed for reproducibility + torch.manual_seed(42) + + # Generate batch of latents + batch_size = 32 + latents = torch.randn(batch_size, *latent_sizes) + + with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: + # Add latents with metadata + for i, latent in enumerate(latents): + preprocessor.add_latent( + latent, + global_idx=i, + metadata={'image_key': f'noise_dist_image_{i}'} + ) + + # Compute CDC results + cdc_path = preprocessor.compute_all(tmp_file.name) + + # Analyze noise properties + dataset = GammaBDataset(cdc_path) + + # Track samples that used CDC vs Gaussian fallback + cdc_samples = 0 + gaussian_samples = 0 + eigenvalue_stats = { + 'min': float('inf'), + 'max': float('-inf'), + 'mean': 0.0, + 'sum': 0.0 + } + + # Verify each sample's CDC components + for i in range(batch_size): + image_key = f'noise_dist_image_{i}' + + # Get eigenvectors and eigenvalues + eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key]) + + # Skip zero eigenvectors (fallback case) + if torch.all(eigvecs[0] == 0): + gaussian_samples += 1 + continue + + # Get the top d_cdc eigenvectors and eigenvalues + top_eigvecs = eigvecs[0] # (d_cdc, d) + top_eigvals = eigvals[0] # (d_cdc,) + + # Basic validity checks + assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}" + assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}" + + # Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM) + assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}" + assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}" + + # 
Update statistics + eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item()) + eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item()) + eigenvalue_stats['sum'] += top_eigvals.sum().item() + + cdc_samples += 1 + + # Compute mean eigenvalue across all CDC samples + if cdc_samples > 0: + eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * 8) # 8 = d_cdc + + # Print final statistics + print(f"\nNoise Distribution Results for latent size {latent_sizes}:") + print(f" CDC samples: {cdc_samples}/{batch_size}") + print(f" Gaussian fallback: {gaussian_samples}/{batch_size}") + print(f" Eigenvalue min: {eigenvalue_stats['min']:.4f}") + print(f" Eigenvalue max: {eigenvalue_stats['max']:.4f}") + print(f" Eigenvalue mean: {eigenvalue_stats['mean']:.4f}") + + # Assertions based on plan objectives + + # 1. CDC noise should be generated for most samples + assert cdc_samples > 0, "No samples used CDC noise injection" + assert gaussian_samples < batch_size // 2, ( + f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}" + ) + + # 2. Eigenvalues should be valid (non-negative and bounded) + assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative" + assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0" + + # 3. 
Mean eigenvalue should be reasonable (not degenerate) + assert eigenvalue_stats['mean'] > 0.05, ( + f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), " + "suggests degenerate CDC components" + ) + +def pytest_configure(config): + """ + Configure performance benchmarking markers + """ + config.addinivalue_line( + "markers", + "performance: mark test to verify CDC-FM computational performance" + ) + config.addinivalue_line( + "markers", + "noise_distribution: mark test to verify noise injection properties" + ) \ No newline at end of file diff --git a/tests/library/test_cdc_rescaling_recommendations.py b/tests/library/test_cdc_rescaling_recommendations.py new file mode 100644 index 00000000..75e8c3fb --- /dev/null +++ b/tests/library/test_cdc_rescaling_recommendations.py @@ -0,0 +1,237 @@ +""" +Tests to validate the CDC rescaling recommendations from paper review. + +These tests check: +1. Gamma parameter interaction with rescaling +2. Spatial adaptivity of eigenvalue scaling +3. 
Verification of fixed vs adaptive rescaling behavior +""" + +import numpy as np +import pytest +import torch +from safetensors import safe_open + +from library.cdc_fm import CDCPreprocessor + + +class TestGammaRescalingInteraction: + """Test that gamma parameter works correctly with eigenvalue rescaling""" + + def test_gamma_scales_eigenvalues_correctly(self, tmp_path): + """Verify gamma multiplier is applied correctly after rescaling""" + # Create one preprocessor per gamma value (three values below) + gamma_values = [0.5, 1.0, 2.0] + eigenvalue_results = {} + + for gamma in gamma_values: + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu" + ) + + # Add identical deterministic data for all runs + for i in range(10): + latent = torch.zeros(16, 4, 4, dtype=torch.float32) + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.1 + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + output_path = tmp_path / f"test_gamma_{gamma}.safetensors" + preprocessor.compute_all(save_path=output_path) + + # Extract eigenvalues + with safe_open(str(output_path), framework="pt", device="cpu") as f: + eigvals = f.get_tensor("eigenvalues/test_image_0").numpy() + eigenvalue_results[gamma] = eigvals + + # With clamping to [1e-3, gamma*1.0], verify gamma changes the upper bound + # Gamma 0.5: max eigenvalue should be at most 0.5 (0.01 tolerance) + # Gamma 1.0: max eigenvalue should be at most 1.0 (0.01 tolerance) + # Gamma 2.0: max eigenvalue should be at most 2.0 (0.01 tolerance) + + max_0p5 = np.max(eigenvalue_results[0.5]) + max_1p0 = np.max(eigenvalue_results[1.0]) + max_2p0 = np.max(eigenvalue_results[2.0]) + + assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}" + assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}" + assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}" + + # All should have
positive eigenvalues >= 1e-3 (clamp lower bound) + assert np.min(eigenvalue_results[0.5][eigenvalue_results[0.5] > 0]) >= 1e-3 + assert np.min(eigenvalue_results[1.0][eigenvalue_results[1.0] > 0]) >= 1e-3 + assert np.min(eigenvalue_results[2.0][eigenvalue_results[2.0] > 0]) >= 1e-3 + + print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}") + print(f"✓ Gamma 1.0 max: {max_1p0:.4f}") + print(f"✓ Gamma 2.0 max: {max_2p0:.4f}") + + def test_large_gamma_maintains_reasonable_scale(self, tmp_path): + """Verify that large gamma values don't cause eigenvalue explosion""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu" + ) + + for i in range(10): + latent = torch.zeros(16, 4, 4, dtype=torch.float32) + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h + w) / 20.0 + i * 0.15 + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + output_path = tmp_path / "test_large_gamma.safetensors" + preprocessor.compute_all(save_path=output_path) + + with safe_open(str(output_path), framework="pt", device="cpu") as f: + all_eigvals = [] + for i in range(10): + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() + all_eigvals.extend(eigvals) + + max_eigval = np.max(all_eigvals) + mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6]) + + # With gamma=10.0 and target_scale=0.1, eigenvalues should be ~1.0 + # But they should still be reasonable (not exploding) + assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma" + assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma" + + print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}") + + +class TestSpatialAdaptivityOfRescaling: + """Test spatial variation in eigenvalue scaling""" + + def test_eigenvalues_vary_spatially(self, tmp_path): + """Check mean of per-sample max eigenvalues stays in (0.01, 10) for tight and loose
clusters""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + # Create two distinct clusters in latent space + # Cluster 1: Tight cluster (low variance) - deterministic spread + for i in range(10): + latent = torch.zeros(16, 4, 4) + # Small variation around 0 + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h + w) / 100.0 + i * 0.01 + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + # Cluster 2: Loose cluster (high variance) - deterministic spread + for i in range(10, 20): + latent = torch.ones(16, 4, 4) * 5.0 + # Large variation around 5.0 + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] += (c + h + w) / 10.0 + (i - 10) * 0.2 + metadata = {'image_key': f'test_image_{i}'} + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + output_path = tmp_path / "test_spatial_variation.safetensors" + preprocessor.compute_all(save_path=output_path) + + with safe_open(str(output_path), framework="pt", device="cpu") as f: + # Get eigenvalues from both clusters + cluster1_eigvals = [] + cluster2_eigvals = [] + + for i in range(10): + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() + cluster1_eigvals.append(np.max(eigvals)) + + for i in range(10, 20): + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() + cluster2_eigvals.append(np.max(eigvals)) + + cluster1_mean = np.mean(cluster1_eigvals) + cluster2_mean = np.mean(cluster2_eigvals) + + print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}") + print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}") + + # With fixed target_scale rescaling, eigenvalues should be similar + # despite different local geometry + # This demonstrates the limitation of fixed rescaling + ratio = cluster2_mean / (cluster1_mean + 1e-10) + print(f"✓ Ratio
(loose/tight): {ratio:.2f}") + + # Both should be rescaled to similar magnitude (~0.1 due to target_scale) + assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range" + assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range" + + +class TestFixedVsAdaptiveRescaling: + """Compare current fixed rescaling vs paper's adaptive approach""" + + def test_current_rescaling_is_uniform(self, tmp_path): + """Demonstrate that current rescaling produces uniform eigenvalue scales""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + # Create samples with varying local density - deterministic + for i in range(20): + latent = torch.zeros(16, 4, 4) + # Some samples clustered, some isolated + if i < 10: + # Dense cluster around origin + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h + w) / 40.0 + i * 0.05 + else: + # Isolated points - larger offset + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h + w) / 40.0 + i * 2.0 + + metadata = {'image_key': f'test_image_{i}'} + + + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + + output_path = tmp_path / "test_uniform_rescaling.safetensors" + preprocessor.compute_all(save_path=output_path) + + with safe_open(str(output_path), framework="pt", device="cpu") as f: + max_eigenvalues = [] + for i in range(20): + eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() + vals = eigvals[eigvals > 1e-6] + if vals.size: # at least one valid eigen-value + max_eigenvalues.append(vals.max()) + + if not max_eigenvalues: # safeguard against empty list + pytest.skip("no valid eigen-values found") + + max_eigenvalues = np.array(max_eigenvalues) + + # Check coefficient of variation (std / mean) + cv = max_eigenvalues.std() / max_eigenvalues.mean() + + print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f},
{np.max(max_eigenvalues):.4f}]") + print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}") + print(f"✓ Coefficient of variation: {cv:.4f}") + + # With clamping, eigenvalues should have relatively low variation + assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping" + # Mean should be reasonable (clamped to [1e-3, gamma*1.0] = [1e-3, 1.0]) + assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_standalone.py b/tests/library/test_cdc_standalone.py index e0943dc4..c7fb2d85 100644 --- a/tests/library/test_cdc_standalone.py +++ b/tests/library/test_cdc_standalone.py @@ -5,10 +5,8 @@ These tests focus on CDC-FM specific functionality without importing the full training infrastructure that has problematic dependencies. """ -import tempfile from pathlib import Path -import numpy as np import pytest import torch from safetensors.torch import save_file diff --git a/tests/library/test_cdc_warning_throttling.py b/tests/library/test_cdc_warning_throttling.py index 41d1b050..d8cba614 100644 --- a/tests/library/test_cdc_warning_throttling.py +++ b/tests/library/test_cdc_warning_throttling.py @@ -7,7 +7,6 @@ Ensures that duplicate warnings for the same sample are not logged repeatedly. import pytest import torch import logging -from pathlib import Path from library.cdc_fm import CDCPreprocessor, GammaBDataset from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples