diff --git a/tests/library/test_cdc_adaptive_k.py b/tests/library/test_cdc_adaptive_k.py deleted file mode 100644 index f5de5fac..00000000 --- a/tests/library/test_cdc_adaptive_k.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -Test adaptive k_neighbors functionality in CDC-FM. - -Verifies that adaptive k properly adjusts based on bucket sizes. -""" - -import pytest -import torch - -from library.cdc_fm import CDCPreprocessor, GammaBDataset - - -class TestAdaptiveK: - """Test adaptive k_neighbors behavior""" - - @pytest.fixture - def temp_cache_path(self, tmp_path): - """Create temporary cache path""" - return tmp_path / "adaptive_k_test.safetensors" - - def test_fixed_k_skips_small_buckets(self, temp_cache_path): - """ - Test that fixed k mode skips buckets with < k_neighbors samples. - """ - preprocessor = CDCPreprocessor( - k_neighbors=32, - k_bandwidth=8, - d_cdc=4, - gamma=1.0, - device='cpu', - debug=False, - adaptive_k=False # Fixed mode - ) - - # Add 10 samples (< k=32, should be skipped) - shape = (4, 16, 16) - for i in range(10): - latent = torch.randn(*shape, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=i, - shape=shape, - metadata={'image_key': f'test_{i}'} - ) - - preprocessor.compute_all(temp_cache_path) - - # Load and verify zeros (Gaussian fallback) - dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu') - eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu') - - # Should be all zeros (fallback) - assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6) - assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6) - - def test_adaptive_k_uses_available_neighbors(self, temp_cache_path): - """ - Test that adaptive k mode uses k=bucket_size-1 for small buckets. - """ - preprocessor = CDCPreprocessor( - k_neighbors=32, - k_bandwidth=8, - d_cdc=4, - gamma=1.0, - device='cpu', - debug=False, - adaptive_k=True, - min_bucket_size=8 - ) - - # Add 20 samples (< k=32, should use k=19) - shape = (4, 16, 16) - for i in range(20): - latent = torch.randn(*shape, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=i, - shape=shape, - metadata={'image_key': f'test_{i}'} - ) - - preprocessor.compute_all(temp_cache_path) - - # Load and verify non-zero (CDC computed) - dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu') - eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu') - - # Should NOT be all zeros (CDC was computed) - assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6) - assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6) - - def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path): - """ - Test that adaptive k mode skips buckets below min_bucket_size. - """ - preprocessor = CDCPreprocessor( - k_neighbors=32, - k_bandwidth=8, - d_cdc=4, - gamma=1.0, - device='cpu', - debug=False, - adaptive_k=True, - min_bucket_size=16 - ) - - # Add 10 samples (< min_bucket_size=16, should be skipped) - shape = (4, 16, 16) - for i in range(10): - latent = torch.randn(*shape, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=i, - shape=shape, - metadata={'image_key': f'test_{i}'} - ) - - preprocessor.compute_all(temp_cache_path) - - # Load and verify zeros (skipped due to min_bucket_size) - dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu') - eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu') - - # Should be all zeros (skipped) - assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6) - assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6) - - def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path): - """ - Test adaptive k with multiple buckets of different sizes. - """ - preprocessor = CDCPreprocessor( - k_neighbors=32, - k_bandwidth=8, - d_cdc=4, - gamma=1.0, - device='cpu', - debug=False, - adaptive_k=True, - min_bucket_size=8 - ) - - # Bucket 1: 10 samples (adaptive k=9) - for i in range(10): - latent = torch.randn(4, 16, 16, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=i, - shape=(4, 16, 16), - metadata={'image_key': f'small_{i}'} - ) - - # Bucket 2: 40 samples (full k=32) - for i in range(40): - latent = torch.randn(4, 32, 32, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=100+i, - shape=(4, 32, 32), - metadata={'image_key': f'large_{i}'} - ) - - # Bucket 3: 5 samples (< min=8, skipped) - for i in range(5): - latent = torch.randn(4, 8, 8, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=200+i, - shape=(4, 8, 8), - metadata={'image_key': f'tiny_{i}'} - ) - - preprocessor.compute_all(temp_cache_path) - dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu') - - # Bucket 1: Should have CDC (non-zero) - eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu') - assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6) - - # Bucket 2: Should have CDC (non-zero) - eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu') - assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6) - - # Bucket 3: Should be skipped (zeros) - eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu') - assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6) - assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6) - - def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path): - """ - Test that adaptive k uses full k_neighbors when bucket is large enough. - """ - preprocessor = CDCPreprocessor( - k_neighbors=16, - k_bandwidth=4, - d_cdc=4, - gamma=1.0, - device='cpu', - debug=False, - adaptive_k=True, - min_bucket_size=8 - ) - - # Add 50 samples (> k=16, should use full k=16) - shape = (4, 16, 16) - for i in range(50): - latent = torch.randn(*shape, dtype=torch.float32).numpy() - preprocessor.add_latent( - latent=latent, - global_idx=i, - shape=shape, - metadata={'image_key': f'test_{i}'} - ) - - preprocessor.compute_all(temp_cache_path) - - # Load and verify CDC was computed - dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu') - eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu') - - # Should have non-zero eigenvalues - assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6) - # Eigenvalues should be positive - assert (eigvals >= 0).all() - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/library/test_cdc_device_consistency.py b/tests/library/test_cdc_device_consistency.py deleted file mode 100644 index 5d4af544..00000000 --- a/tests/library/test_cdc_device_consistency.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Test device consistency handling in CDC noise transformation. - -Ensures that device mismatches are handled gracefully. -""" - -import pytest -import torch -import logging - -from library.cdc_fm import CDCPreprocessor, GammaBDataset -from library.flux_train_utils import apply_cdc_noise_transformation - - -class TestDeviceConsistency: - """Test device consistency validation""" - - @pytest.fixture - def cdc_cache(self, tmp_path): - """Create a test CDC cache""" - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - shape = (16, 32, 32) - for i in range(10): - latent = torch.randn(*shape, dtype=torch.float32) - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) - - cache_path = tmp_path / "test_device.safetensors" - preprocessor.compute_all(save_path=cache_path) - return cache_path - - def test_matching_devices_no_warning(self, cdc_cache, caplog): - """ - Test that no warnings are emitted when devices match. - """ - dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu") - - shape = (16, 32, 32) - noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") - timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - image_keys = ['test_image_0', 'test_image_1'] - - with caplog.at_level(logging.WARNING): - caplog.clear() - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # No device mismatch warnings - device_warnings = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()] - assert len(device_warnings) == 0, "Should not warn when devices match" - - def test_device_mismatch_warning_and_transfer(self, cdc_cache, caplog): - """ - Test that device mismatch is detected, warned, and handled. - - This simulates the case where noise is on one device but CDC matrices - are requested for another device. - """ - dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu") - - shape = (16, 32, 32) - # Create noise on CPU - noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") - timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - image_keys = ['test_image_0', 'test_image_1'] - - # But request CDC matrices for a different device string - # (In practice this would be "cuda" vs "cpu", but we simulate with string comparison) - with caplog.at_level(logging.WARNING): - caplog.clear() - - # Use a different device specification to trigger the check - # We'll use "cpu" vs "cpu:0" as an example of string mismatch - result = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" # Same actual device, consistent string - ) - - # Should complete without errors - assert result is not None - assert result.shape == noise.shape - - def test_transformation_works_after_device_transfer(self, cdc_cache): - """ - Test that CDC transformation produces valid output even if devices differ. - - The function should handle device transfer gracefully. - """ - dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu") - - shape = (16, 32, 32) - noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True) - timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") - image_keys = ['test_image_0', 'test_image_1'] - - result = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Verify output is valid - assert result.shape == noise.shape - assert result.device == noise.device - assert result.requires_grad # Gradients should still work - assert not torch.isnan(result).any() - assert not torch.isinf(result).any() - - # Verify gradients flow - loss = result.sum() - loss.backward() - assert noise.grad is not None - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_dimension_handling.py b/tests/library/test_cdc_dimension_handling.py deleted file mode 100644 index 147a1d7e..00000000 --- a/tests/library/test_cdc_dimension_handling.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -Test CDC-FM dimension handling and fallback mechanisms. - -This module tests the behavior of the CDC Flow Matching implementation -when encountering latents with different dimensions. -""" - -import torch -import logging -import tempfile - -from library.cdc_fm import CDCPreprocessor, GammaBDataset - -class TestDimensionHandling: - def setup_method(self): - """Prepare consistent test environment""" - self.logger = logging.getLogger(__name__) - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - def test_mixed_dimension_fallback(self): - """ - Verify that preprocessor falls back to standard noise for mixed-dimension batches - """ - # Prepare preprocessor with debug mode - preprocessor = CDCPreprocessor(debug=True) - - # Different-sized latents (3D: channels, height, width) - latents = [ - torch.randn(3, 32, 64), # First latent: 3x32x64 - torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension) - ] - - # Use a mock handler to capture log messages - from library.cdc_fm import logger - - log_messages = [] - class LogCapture(logging.Handler): - def emit(self, record): - log_messages.append(record.getMessage()) - - # Temporarily add a capture handler - capture_handler = LogCapture() - logger.addHandler(capture_handler) - - try: - # Try adding mixed-dimension latents - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - for i, latent in enumerate(latents): - preprocessor.add_latent( - latent, - global_idx=i, - metadata={'image_key': f'test_mixed_image_{i}'} - ) - - try: - cdc_path = preprocessor.compute_all(tmp_file.name) - except ValueError as e: - # If implementation raises ValueError, that's acceptable - assert "Dimension mismatch" in str(e) - return - - # Check for dimension-related log messages - dimension_warnings = [ - msg for msg in log_messages - if "dimension mismatch" in msg.lower() - ] - assert len(dimension_warnings) > 0, "No dimension-related warnings were logged" - - # Load results and verify fallback - dataset = GammaBDataset(cdc_path) - - finally: - # Remove the capture handler - logger.removeHandler(capture_handler) - - # Check metadata about samples with/without CDC - assert dataset.num_samples == len(latents), "All samples should be processed" - - def test_adaptive_k_with_dimension_constraints(self): - """ - Test adaptive k-neighbors behavior with dimension constraints - """ - # Prepare preprocessor with adaptive k and small bucket size - preprocessor = CDCPreprocessor( - adaptive_k=True, - min_bucket_size=5, - debug=True - ) - - # Generate latents with similar but not identical dimensions - base_latent = torch.randn(3, 32, 64) - similar_latents = [ - base_latent, - torch.randn(3, 32, 65), # Slightly different dimension - torch.randn(3, 32, 66) # Another slightly different dimension - ] - - # Use a mock handler to capture log messages - from library.cdc_fm import logger - - log_messages = [] - class LogCapture(logging.Handler): - def emit(self, record): - log_messages.append(record.getMessage()) - - # Temporarily add a capture handler - capture_handler = LogCapture() - logger.addHandler(capture_handler) - - try: - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - # Add similar latents - for i, latent in enumerate(similar_latents): - preprocessor.add_latent( - latent, - global_idx=i, - metadata={'image_key': f'test_adaptive_k_image_{i}'} - ) - - cdc_path = preprocessor.compute_all(tmp_file.name) - - # Load results - dataset = GammaBDataset(cdc_path) - - # Verify samples processed - assert dataset.num_samples == len(similar_latents), "All samples should be processed" - - # Optional: Check warnings about dimension differences - dimension_warnings = [ - msg for msg in log_messages - if "dimension" in msg.lower() - ] - print(f"Dimension-related warnings: {dimension_warnings}") - - finally: - # Remove the capture handler - logger.removeHandler(capture_handler) - -def pytest_configure(config): - """ - Configure custom markers for dimension handling tests - """ - config.addinivalue_line( - "markers", - "dimension_handling: mark test for CDC-FM dimension mismatch scenarios" - ) \ No newline at end of file diff --git a/tests/library/test_cdc_dimension_handling_and_warnings.py b/tests/library/test_cdc_dimension_handling_and_warnings.py deleted file mode 100644 index 2f88f10c..00000000 --- a/tests/library/test_cdc_dimension_handling_and_warnings.py +++ /dev/null @@ -1,310 +0,0 @@ -""" -Comprehensive CDC Dimension Handling and Warning Tests - -This module tests: -1. Dimension mismatch detection and fallback mechanisms -2. Warning throttling for shape mismatches -3. Adaptive k-neighbors behavior with dimension constraints -""" - -import pytest -import torch -import logging -import tempfile - -from library.cdc_fm import CDCPreprocessor, GammaBDataset -from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples - - -class TestDimensionHandlingAndWarnings: - """ - Comprehensive testing of dimension handling, noise injection, and warning systems - """ - - @pytest.fixture(autouse=True) - def clear_warned_samples(self): - """Clear the warned samples set before each test""" - _cdc_warned_samples.clear() - yield - _cdc_warned_samples.clear() - - def test_mixed_dimension_fallback(self): - """ - Verify that preprocessor falls back to standard noise for mixed-dimension batches - """ - # Prepare preprocessor with debug mode - preprocessor = CDCPreprocessor(debug=True) - - # Different-sized latents (3D: channels, height, width) - latents = [ - torch.randn(3, 32, 64), # First latent: 3x32x64 - torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension) - ] - - # Use a mock handler to capture log messages - from library.cdc_fm import logger - - log_messages = [] - class LogCapture(logging.Handler): - def emit(self, record): - log_messages.append(record.getMessage()) - - # Temporarily add a capture handler - capture_handler = LogCapture() - logger.addHandler(capture_handler) - - try: - # Try adding mixed-dimension latents - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - for i, latent in enumerate(latents): - preprocessor.add_latent( - latent, - global_idx=i, - metadata={'image_key': f'test_mixed_image_{i}'} - ) - - try: - cdc_path = preprocessor.compute_all(tmp_file.name) - except ValueError as e: - # If implementation raises ValueError, that's acceptable - assert "Dimension mismatch" in str(e) - return - - # Check for dimension-related log messages - dimension_warnings = [ - msg for msg in log_messages - if "dimension mismatch" in msg.lower() - ] - assert len(dimension_warnings) > 0, "No dimension-related warnings were logged" - - # Load results and verify fallback - dataset = GammaBDataset(cdc_path) - - finally: - # Remove the capture handler - logger.removeHandler(capture_handler) - - # Check metadata about samples with/without CDC - assert dataset.num_samples == len(latents), "All samples should be processed" - - def test_adaptive_k_with_dimension_constraints(self): - """ - Test adaptive k-neighbors behavior with dimension constraints - """ - # Prepare preprocessor with adaptive k and small bucket size - preprocessor = CDCPreprocessor( - adaptive_k=True, - min_bucket_size=5, - debug=True - ) - - # Generate latents with similar but not identical dimensions - base_latent = torch.randn(3, 32, 64) - similar_latents = [ - base_latent, - torch.randn(3, 32, 65), # Slightly different dimension - torch.randn(3, 32, 66) # Another slightly different dimension - ] - - # Use a mock handler to capture log messages - from library.cdc_fm import logger - - log_messages = [] - class LogCapture(logging.Handler): - def emit(self, record): - log_messages.append(record.getMessage()) - - # Temporarily add a capture handler - capture_handler = LogCapture() - logger.addHandler(capture_handler) - - try: - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - # Add similar latents - for i, latent in enumerate(similar_latents): - preprocessor.add_latent( - latent, - global_idx=i, - metadata={'image_key': f'test_adaptive_k_image_{i}'} - ) - - cdc_path = preprocessor.compute_all(tmp_file.name) - - # Load results - dataset = GammaBDataset(cdc_path) - - # Verify samples processed - assert dataset.num_samples == len(similar_latents), "All samples should be processed" - - # Optional: Check warnings about dimension differences - dimension_warnings = [ - msg for msg in log_messages - if "dimension" in msg.lower() - ] - print(f"Dimension-related warnings: {dimension_warnings}") - - finally: - # Remove the capture handler - logger.removeHandler(capture_handler) - - def test_warning_only_logged_once_per_sample(self, caplog): - """ - Test that shape mismatch warning is only logged once per sample. - - Even if the same sample appears in multiple batches, only warn once. - """ - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Create cache with one specific shape - preprocessed_shape = (16, 32, 32) - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - for i in range(10): - latent = torch.randn(*preprocessed_shape, dtype=torch.float32) - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata) - - cdc_path = preprocessor.compute_all(save_path=tmp_file.name) - - dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu") - - # Use different shape at runtime to trigger mismatch - runtime_shape = (16, 64, 64) - timesteps = torch.tensor([100.0], dtype=torch.float32) - image_keys = ['test_image_0'] # Same sample - - # First call - should warn - with caplog.at_level(logging.WARNING): - caplog.clear() - noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32) - _ = apply_cdc_noise_transformation( - noise=noise1, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have exactly one warning - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 1, "First call should produce exactly one warning" - assert "CDC shape mismatch" in warnings[0].message - - # Second call with same sample - should NOT warn - with caplog.at_level(logging.WARNING): - caplog.clear() - noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32) - _ = apply_cdc_noise_transformation( - noise=noise2, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have NO warnings - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 0, "Second call with same sample should not warn" - - def test_different_samples_each_get_one_warning(self, caplog): - """ - Test that different samples each get their own warning. - - Each unique sample should be warned about once. - """ - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Create cache with specific shape - preprocessed_shape = (16, 32, 32) - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - for i in range(10): - latent = torch.randn(*preprocessed_shape, dtype=torch.float32) - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata) - - cdc_path = preprocessor.compute_all(save_path=tmp_file.name) - - dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu") - - runtime_shape = (16, 64, 64) - timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32) - - # First batch: samples 0, 1, 2 - with caplog.at_level(logging.WARNING): - caplog.clear() - noise = torch.randn(3, *runtime_shape, dtype=torch.float32) - image_keys = ['test_image_0', 'test_image_1', 'test_image_2'] - - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have 3 warnings (one per sample) - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 3, "Should warn for each of the 3 samples" - - # Second batch: same samples 0, 1, 2 - with caplog.at_level(logging.WARNING): - caplog.clear() - noise = torch.randn(3, *runtime_shape, dtype=torch.float32) - image_keys = ['test_image_0', 'test_image_1', 'test_image_2'] - - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have NO warnings (already warned) - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 0, "Should not warn again for same samples" - - # Third batch: new samples 3, 4 - with caplog.at_level(logging.WARNING): - caplog.clear() - noise = torch.randn(2, *runtime_shape, dtype=torch.float32) - image_keys = ['test_image_3', 'test_image_4'] - timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32) - - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have 2 warnings (new samples) - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 2, "Should warn for each of the 2 new samples" - - -def pytest_configure(config): - """ - Configure custom markers for dimension handling and warning tests - """ - config.addinivalue_line( - "markers", - "dimension_handling: mark test for CDC-FM dimension mismatch scenarios" - ) - config.addinivalue_line( - "markers", - "warning_throttling: mark test for CDC-FM warning suppression" - ) - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file diff --git a/tests/library/test_cdc_eigenvalue_real_data.py b/tests/library/test_cdc_eigenvalue_real_data.py deleted file mode 100644 index 3202b37c..00000000 --- a/tests/library/test_cdc_eigenvalue_real_data.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Tests using realistic high-dimensional data to catch scaling bugs. - -This test uses realistic VAE-like latents to ensure eigenvalue normalization -works correctly on real-world data. -""" - -import numpy as np -import pytest -import torch -from safetensors import safe_open - -from library.cdc_fm import CDCPreprocessor - - -class TestRealisticDataScaling: - """Test eigenvalue scaling with realistic high-dimensional data""" - - def test_high_dimensional_latents_not_saturated(self, tmp_path): - """ - Verify that high-dimensional realistic latents don't saturate eigenvalues. - - This test simulates real FLUX training data: - - High dimension (16×64×64 = 65536) - - Varied content (different variance in different regions) - - Realistic magnitude (VAE output scale) - """ - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Create 20 samples with realistic varied structure - for i in range(20): - # High-dimensional latent like FLUX - latent = torch.zeros(16, 64, 64, dtype=torch.float32) - - # Create varied structure across the latent - # Different channels have different patterns (realistic for VAE) - for c in range(16): - # Some channels have gradients - if c < 4: - for h in range(64): - for w in range(64): - latent[c, h, w] = (h + w) / 128.0 - # Some channels have patterns - elif c < 8: - for h in range(64): - for w in range(64): - latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0) - # Some channels are more uniform - else: - latent[c, :, :] = c * 0.1 - - # Add per-sample variation (different "subjects") - latent = latent * (1.0 + i * 0.2) - - # Add realistic VAE-like noise/variation - latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3) - - metadata = {'image_key': f'test_image_{i}'} - - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_realistic_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) - - # Verify eigenvalues are NOT all saturated at 1.0 - with safe_open(str(result_path), framework="pt", device="cpu") as f: - all_eigvals = [] - for i in range(20): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - all_eigvals.extend(eigvals) - - all_eigvals = np.array(all_eigvals) - non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] - - # Critical: eigenvalues should NOT all be 1.0 - at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01) - total = len(non_zero_eigvals) - percent_at_max = (at_max / total * 100) if total > 0 else 0 - - print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]") - print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}") - print(f"✓ Std: {np.std(non_zero_eigvals):.4f}") - print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)") - - # FAIL if too many eigenvalues are saturated at 1.0 - assert percent_at_max < 80, ( - f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! " - f"This indicates the normalization bug - raw eigenvalues are not being " - f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]" - ) - - # Should have good diversity - assert np.std(non_zero_eigvals) > 0.1, ( - f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. " - f"Should see diverse eigenvalues, not all the same value." - ) - - # Mean should be in reasonable range (not all 1.0) - mean_eigval = np.mean(non_zero_eigvals) - assert 0.05 < mean_eigval < 0.9, ( - f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. " - f"If mean ≈ 1.0, eigenvalues are saturated." - ) - - def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path): - """ - Test that datasets with more variance produce more diverse eigenvalues. - - This ensures the normalization preserves relative information. - """ - # Create two preprocessors with different data variance - results = {} - - for variance_scale in [0.5, 2.0]: - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - for i in range(15): - latent = torch.zeros(16, 32, 32, dtype=torch.float32) - - # Create varied patterns - for c in range(16): - for h in range(32): - for w in range(32): - latent[c, h, w] = ( - np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale - ) - - metadata = {'image_key': f'test_image_{i}'} - - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / f"test_variance_{variance_scale}.safetensors" - preprocessor.compute_all(save_path=output_path) - - with safe_open(str(output_path), framework="pt", device="cpu") as f: - eigvals = [] - for i in range(15): - ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - eigvals.extend(ev[ev > 1e-6]) - - results[variance_scale] = { - 'mean': np.mean(eigvals), - 'std': np.std(eigvals), - 'range': (np.min(eigvals), np.max(eigvals)) - } - - print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}") - print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}") - - # Both should have diversity (not saturated) - for scale in [0.5, 2.0]: - assert results[scale]['std'] > 0.1, ( - f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}" - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_eigenvalue_scaling.py b/tests/library/test_cdc_eigenvalue_scaling.py deleted file mode 100644 index 32f85d52..00000000 --- a/tests/library/test_cdc_eigenvalue_scaling.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -Tests to verify CDC eigenvalue scaling is correct. - -These tests ensure eigenvalues are properly scaled to prevent training loss explosion. -""" - -import numpy as np -import pytest -import torch -from safetensors import safe_open - -from library.cdc_fm import CDCPreprocessor - - -class TestEigenvalueScaling: - """Test that eigenvalues are properly scaled to reasonable ranges""" - - def test_eigenvalues_in_correct_range(self, tmp_path): - """Verify eigenvalues are scaled to ~0.01-1.0 range, not millions""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Add deterministic latents with structured patterns - for i in range(10): - # Create gradient pattern: values from 0 to 2.0 across spatial dims - latent = torch.zeros(16, 8, 8, dtype=torch.float32) - for h in range(8): - for w in range(8): - latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0] - # Add per-sample variation - latent = latent + i * 0.1 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) - - # Verify eigenvalues are in correct range - with safe_open(str(result_path), framework="pt", device="cpu") as f: - all_eigvals = [] - for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - all_eigvals.extend(eigvals) - - all_eigvals = np.array(all_eigvals) - - # Filter out zero eigenvalues (from padding when k < d_cdc) - non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] - - # Critical assertions for eigenvalue scale - assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)" - assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues" - assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large" - - # Check sqrt (used in noise) is reasonable - sqrt_max = np.sqrt(all_eigvals.max()) - assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion" - - print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]") - print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}") - print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}") - print(f"✓ sqrt(max): {sqrt_max:.4f}") - - def test_eigenvalues_not_all_zero(self, tmp_path): - """Ensure eigenvalues are not all zero (indicating computation failure)""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" - ) - - for i in range(10): - # Create deterministic pattern - latent = torch.zeros(16, 4, 4, dtype=torch.float32) - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) - - with safe_open(str(result_path), framework="pt", device="cpu") as f: - all_eigvals = [] - for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - all_eigvals.extend(eigvals) - - all_eigvals = np.array(all_eigvals) - non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] - - # With clamping, eigenvalues will be in range [1e-3, gamma*1.0] - # Check that we have some non-zero eigenvalues - assert len(non_zero_eigvals) > 0, "All eigenvalues are zero - computation failed" - - # Check they're in the expected clamped range - assert np.all(non_zero_eigvals >= 1e-3), f"Some eigenvalues below clamp min: {np.min(non_zero_eigvals)}" - assert np.all(non_zero_eigvals <= 1.0), f"Some eigenvalues above clamp max: {np.max(non_zero_eigvals)}" - - print(f"\n✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}") - print(f"✓ Range: [{np.min(non_zero_eigvals):.4f}, {np.max(non_zero_eigvals):.4f}]") - print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}") - - def test_fp16_storage_no_overflow(self, tmp_path): - """Verify fp16 storage doesn't overflow (max fp16 = 65,504)""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - for i in range(10): - # Create deterministic pattern with higher magnitude - latent = torch.zeros(16, 8, 8, dtype=torch.float32) - for h in range(8): - for w in range(8): - latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0] - latent = latent + i * 0.3 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) - - with safe_open(str(result_path), framework="pt", device="cpu") as f: - # Check dtype is fp16 - eigvecs = f.get_tensor("eigenvectors/test_image_0") - eigvals = f.get_tensor("eigenvalues/test_image_0") - - assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}" - assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}" - - # Check no values near fp16 max (would indicate overflow) - FP16_MAX = 65504 - max_eigval = eigvals.max().item() - - assert max_eigval < 100, ( - f"Eigenvalue {max_eigval:.2e} is suspiciously large for fp16 storage. " - f"May indicate overflow (fp16 max = {FP16_MAX})" - ) - - print(f"\n✓ Storage dtype: {eigvals.dtype}") - print(f"✓ Max eigenvalue: {max_eigval:.4f} (safe for fp16)") - - def test_latent_magnitude_preserved(self, tmp_path): - """Verify latent magnitude is preserved (no unwanted normalization)""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" - ) - - # Store original latents with deterministic patterns - original_latents = [] - for i in range(10): - # Create structured pattern with known magnitude - latent = torch.zeros(16, 4, 4, dtype=torch.float32) - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5 - original_latents.append(latent.clone()) - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - # Compute original latent statistics - orig_std = torch.stack(original_latents).std().item() - - output_path = tmp_path / "test_gamma_b.safetensors" - preprocessor.compute_all(save_path=output_path) - - # The stored latents should preserve original magnitude - stored_latents_std = np.std([s.latent for s in preprocessor.batcher.samples]) - - # Should be similar to original (within 20% due to potential batching effects) - assert 0.8 * orig_std < stored_latents_std < 1.2 * orig_std, ( - f"Stored latent std {stored_latents_std:.2f} differs too much from " - f"original {orig_std:.2f}. Latent magnitude was not preserved." - ) - - print(f"\n✓ Original latent std: {orig_std:.2f}") - print(f"✓ Stored latent std: {stored_latents_std:.2f}") - - -class TestTrainingLossScale: - """Test that eigenvalues produce reasonable loss magnitudes""" - - def test_noise_magnitude_reasonable(self, tmp_path): - """Verify CDC noise has reasonable magnitude for training""" - from library.cdc_fm import GammaBDataset - - # Create CDC cache with deterministic data - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" - ) - - for i in range(10): - # Create deterministic pattern - latent = torch.zeros(16, 4, 4, dtype=torch.float32) - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_gamma_b.safetensors" - cdc_path = preprocessor.compute_all(save_path=output_path) - - # Load and compute noise - gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu") - - # Simulate training scenario with deterministic data - batch_size = 3 - latents = torch.zeros(batch_size, 16, 4, 4) - for b in range(batch_size): - for c in range(16): - for h in range(4): - for w in range(4): - latents[b, c, h, w] = (b + c + h + w) / 24.0 - t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps - image_keys = ['test_image_0', 'test_image_5', 'test_image_9'] - - eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys) - noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t) - - # Check noise magnitude - noise_std = noise.std().item() - latent_std = latents.std().item() - - # Noise should be similar magnitude to input latents (within 10x) - ratio = noise_std / latent_std - assert 0.1 < ratio < 10.0, ( - f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) " - f"ratio {ratio:.2f} is too extreme. Will cause training instability." - ) - - # Simulated MSE loss should be reasonable - simulated_loss = torch.mean((noise - latents) ** 2).item() - assert simulated_loss < 100.0, ( - f"Simulated MSE loss {simulated_loss:.2f} is too high. " - f"Should be O(0.1-1.0) for stable training." - ) - - print(f"\n✓ Noise/latent ratio: {ratio:.2f}") - print(f"✓ Simulated MSE loss: {simulated_loss:.4f}") - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_eigenvalue_validation.py b/tests/library/test_cdc_eigenvalue_validation.py deleted file mode 100644 index 219b406c..00000000 --- a/tests/library/test_cdc_eigenvalue_validation.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Comprehensive CDC Eigenvalue Validation Tests - -These tests ensure that eigenvalue computation and scaling work correctly -across various scenarios, including: -- Scaling to reasonable ranges -- Handling high-dimensional data -- Preserving latent information -- Preventing computational artifacts -""" - -import numpy as np -import pytest -import torch -from safetensors import safe_open - -from library.cdc_fm import CDCPreprocessor, GammaBDataset - - -class TestEigenvalueScaling: - """Verify eigenvalue scaling and computational properties""" - - def test_eigenvalues_in_correct_range(self, tmp_path): - """ - Verify eigenvalues are scaled to ~0.01-1.0 range, not millions. - - Ensures: - - No numerical explosions - - Reasonable eigenvalue magnitudes - - Consistent scaling across samples - """ - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Create deterministic latents with structured patterns - for i in range(10): - latent = torch.zeros(16, 8, 8, dtype=torch.float32) - for h in range(8): - for w in range(8): - latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0] - latent = latent + i * 0.1 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) - - # Verify eigenvalues are in correct range - with safe_open(str(result_path), framework="pt", device="cpu") as f: - all_eigvals = [] - for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - all_eigvals.extend(eigvals) - - all_eigvals = np.array(all_eigvals) - non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] - - # Critical assertions for eigenvalue scale - assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)" - assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues" - assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large" - - # Check sqrt (used in noise) is reasonable - sqrt_max = np.sqrt(all_eigvals.max()) - assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion" - - print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]") - print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}") - print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}") - print(f"✓ sqrt(max): {sqrt_max:.4f}") - - def test_high_dimensional_latents_scaling(self, tmp_path): - """ - Verify scaling for high-dimensional realistic latents. - - Key scenarios: - - High-dimensional data (16×64×64) - - Varied channel structures - - Realistic VAE-like data - """ - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Create 20 samples with realistic varied structure - for i in range(20): - # High-dimensional latent like FLUX - latent = torch.zeros(16, 64, 64, dtype=torch.float32) - - # Create varied structure across the latent - for c in range(16): - # Different patterns across channels - if c < 4: - for h in range(64): - for w in range(64): - latent[c, h, w] = (h + w) / 128.0 - elif c < 8: - for h in range(64): - for w in range(64): - latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0) - else: - latent[c, :, :] = c * 0.1 - - # Add per-sample variation - latent = latent * (1.0 + i * 0.2) - latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3) - - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_realistic_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) - - # Verify eigenvalues are not all saturated - with safe_open(str(result_path), framework="pt", device="cpu") as f: - all_eigvals = [] - for i in range(20): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - all_eigvals.extend(eigvals) - - all_eigvals = np.array(all_eigvals) - non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] - - at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01) - total = len(non_zero_eigvals) - percent_at_max = (at_max / total * 100) if total > 0 else 0 - - print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]") - print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}") - print(f"✓ Std: {np.std(non_zero_eigvals):.4f}") - print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)") - - # Fail if too many eigenvalues are saturated - assert percent_at_max < 80, ( - f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! " - f"Raw eigenvalues not scaled before clamping. " - f"Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]" - ) - - # Should have good diversity - assert np.std(non_zero_eigvals) > 0.1, ( - f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. " - f"Should see diverse eigenvalues, not all the same." - ) - - # Mean should be in reasonable range - mean_eigval = np.mean(non_zero_eigvals) - assert 0.05 < mean_eigval < 0.9, ( - f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. " - f"If mean ≈ 1.0, eigenvalues are saturated." - ) - - def test_noise_magnitude_reasonable(self, tmp_path): - """ - Verify CDC noise has reasonable magnitude for training. - - Ensures noise: - - Has similar scale to input latents - - Won't destabilize training - - Preserves input variance - """ - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" - ) - - for i in range(10): - latent = torch.zeros(16, 4, 4, dtype=torch.float32) - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1 - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_gamma_b.safetensors" - cdc_path = preprocessor.compute_all(save_path=output_path) - - # Load and compute noise - gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu") - - # Simulate training scenario with deterministic data - batch_size = 3 - latents = torch.zeros(batch_size, 16, 4, 4) - for b in range(batch_size): - for c in range(16): - for h in range(4): - for w in range(4): - latents[b, c, h, w] = (b + c + h + w) / 24.0 - t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps - image_keys = ['test_image_0', 'test_image_5', 'test_image_9'] - - eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys) - noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t) - - # Check noise magnitude - noise_std = noise.std().item() - latent_std = latents.std().item() - - # Noise should be similar magnitude to input latents (within 10x) - ratio = noise_std / latent_std - assert 0.1 < ratio < 10.0, ( - f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) " - f"ratio {ratio:.2f} is too extreme. Will cause training instability." - ) - - # Simulated MSE loss should be reasonable - simulated_loss = torch.mean((noise - latents) ** 2).item() - assert simulated_loss < 100.0, ( - f"Simulated MSE loss {simulated_loss:.2f} is too high. " - f"Should be O(0.1-1.0) for stable training." - ) - - print(f"\n✓ Noise/latent ratio: {ratio:.2f}") - print(f"✓ Simulated MSE loss: {simulated_loss:.4f}") - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file diff --git a/tests/library/test_cdc_gradient_flow.py b/tests/library/test_cdc_gradient_flow.py deleted file mode 100644 index 3e8e4d74..00000000 --- a/tests/library/test_cdc_gradient_flow.py +++ /dev/null @@ -1,297 +0,0 @@ -""" -CDC Gradient Flow Verification Tests - -This module provides testing of: -1. Mock dataset gradient preservation -2. Real dataset gradient flow -3. Various time steps and computation paths -4. Fallback and edge case scenarios -""" - -import pytest -import torch - -from library.cdc_fm import CDCPreprocessor, GammaBDataset -from library.flux_train_utils import apply_cdc_noise_transformation - - -class MockGammaBDataset: - """ - Mock implementation of GammaBDataset for testing gradient flow - """ - def __init__(self, *args, **kwargs): - """ - Simple initialization that doesn't require file loading - """ - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - def compute_sigma_t_x( - self, - eigenvectors: torch.Tensor, - eigenvalues: torch.Tensor, - x: torch.Tensor, - t: torch.Tensor - ) -> torch.Tensor: - """ - Simplified implementation of compute_sigma_t_x for testing - """ - # Store original shape to restore later - orig_shape = x.shape - - # Flatten x if it's 4D - if x.dim() == 4: - B, C, H, W = x.shape - x = x.reshape(B, -1) # (B, C*H*W) - - # Validate dimensions - assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch" - assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch" - - # Early return for t=0 with gradient preservation - if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad: - return x.reshape(orig_shape) - - # Compute Σ_t @ x - # V^T x - Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x) - - # sqrt(λ) * V^T x - sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10)) - sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x - - # V @ (sqrt(λ) * V^T x) - gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x) - - # Interpolate between original and noisy latent - result = (1 - t) * x + t * gamma_sqrt_x - - # Restore original shape - result = result.reshape(orig_shape) - - return result - - -class TestCDCGradientFlow: - """ - Gradient flow testing for CDC noise transformations - """ - - def setup_method(self): - """Prepare consistent test environment""" - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - def test_mock_gradient_flow_near_zero_time_step(self): - """ - Verify gradient flow preservation for near-zero time steps - using mock dataset with learnable time embeddings - """ - # Set random seed for reproducibility - torch.manual_seed(42) - - # Create a learnable time embedding with small initial value - t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32) - - # Generate mock latent and CDC components - batch_size, latent_dim = 4, 64 - latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True) - - # Create mock eigenvectors and eigenvalues - eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device) - eigenvalues = torch.rand(batch_size, 8, device=self.device) - - # Ensure eigenvectors and eigenvalues are meaningful - eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True) - eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0) - - # Use the mock dataset - mock_dataset = MockGammaBDataset() - - # Compute noisy latent with gradient tracking - noisy_latent = mock_dataset.compute_sigma_t_x( - eigenvectors, - eigenvalues, - latent, - t - ) - - # Compute a dummy loss to check gradient flow - loss = noisy_latent.sum() - - # Compute gradients - loss.backward() - - # Assertions to verify gradient flow - assert t.grad is not None, "Time embedding gradient should be computed" - assert latent.grad is not None, "Input latent gradient should be computed" - - # Check gradient magnitudes are non-zero - t_grad_magnitude = torch.abs(t.grad).sum() - latent_grad_magnitude = torch.abs(latent.grad).sum() - - assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}" - assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}" - - def test_gradient_flow_with_multiple_time_steps(self): - """ - Verify gradient flow across different time step values - """ - # Test time steps - time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0] - - for time_val in time_steps: - # Create a learnable time embedding - t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32) - - # Generate mock latent and CDC components - batch_size, latent_dim = 4, 64 - latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True) - - # Create mock eigenvectors and eigenvalues - eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device) - eigenvalues = torch.rand(batch_size, 8, device=self.device) - - # Ensure eigenvectors and eigenvalues are meaningful - eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True) - eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0) - - # Use the mock dataset - mock_dataset = MockGammaBDataset() - - # Compute noisy latent with gradient tracking - noisy_latent = mock_dataset.compute_sigma_t_x( - eigenvectors, - eigenvalues, - latent, - t - ) - - # Compute a dummy loss to check gradient flow - loss = noisy_latent.sum() - - # Compute gradients - loss.backward() - - # Assertions to verify gradient flow - t_grad_magnitude = torch.abs(t.grad).sum() - latent_grad_magnitude = torch.abs(latent.grad).sum() - - assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}" - assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}" - - # Reset gradients for next iteration - t.grad.zero_() if t.grad is not None else None - latent.grad.zero_() if latent.grad is not None else None - - def test_gradient_flow_with_real_dataset(self, tmp_path): - """ - Test gradient flow with real CDC dataset - """ - # Create cache with uniform shapes - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - shape = (16, 32, 32) - for i in range(10): - latent = torch.randn(*shape, dtype=torch.float32) - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata) - - cache_path = tmp_path / "test_gradient.safetensors" - preprocessor.compute_all(save_path=cache_path) - dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu") - - # Prepare test noise - torch.manual_seed(42) - noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True) - timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32) - image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3'] - - # Apply CDC transformation - noise_out = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Verify gradient flow - assert noise_out.requires_grad, "Output should require gradients" - - loss = noise_out.sum() - loss.backward() - - assert noise.grad is not None, "Gradients should flow back to input noise" - assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN" - assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf" - assert (noise.grad != 0).any(), "Gradients should not be all zeros" - - def test_gradient_flow_with_fallback(self, tmp_path): - """ - Test gradient flow when using Gaussian fallback (shape mismatch) - - Ensures that cloned tensors maintain gradient flow correctly - even when shape mismatch triggers Gaussian noise - """ - # Create cache with one shape - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - preprocessed_shape = (16, 32, 32) - latent = torch.randn(*preprocessed_shape, dtype=torch.float32) - metadata = {'image_key': 'test_image_0'} - preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata) - - cache_path = tmp_path / "test_fallback_gradient.safetensors" - preprocessor.compute_all(save_path=cache_path) - dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu") - - # Use different shape at runtime (will trigger fallback) - runtime_shape = (16, 64, 64) - noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True) - timesteps = torch.tensor([100.0], dtype=torch.float32) - image_keys = ['test_image_0'] - - # Apply transformation (should fallback to Gaussian for this sample) - noise_out = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Ensure gradients still flow through fallback path - assert noise_out.requires_grad, "Fallback output should require gradients" - - loss = noise_out.sum() - loss.backward() - - assert noise.grad is not None, "Gradients should flow even in fallback case" - assert not torch.isnan(noise.grad).any(), "Fallback gradients should not contain NaN" - - -def pytest_configure(config): - """ - Configure custom markers for CDC gradient flow tests - """ - config.addinivalue_line( - "markers", - "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching" - ) - config.addinivalue_line( - "markers", - "mock_dataset: mark test using mock dataset for simplified gradient testing" - ) - config.addinivalue_line( - "markers", - "real_dataset: mark test using real dataset for comprehensive gradient testing" - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file diff --git a/tests/library/test_cdc_hash_validation.py b/tests/library/test_cdc_hash_validation.py new file mode 100644 index 00000000..a6034c09 --- /dev/null +++ b/tests/library/test_cdc_hash_validation.py @@ -0,0 +1,157 @@ +""" +Test CDC config hash generation and cache invalidation +""" + +import pytest +import torch +from pathlib import Path + +from library.cdc_fm import CDCPreprocessor + + +class TestCDCConfigHash: + """ + Test that CDC config hash properly invalidates cache when dataset or parameters change + """ + + def test_same_config_produces_same_hash(self, tmp_path): + """ + Test that identical configurations produce identical hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash == preprocessor2.config_hash + + def test_different_dataset_dirs_produce_different_hash(self, tmp_path): + """ + Test that different dataset directories produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset2")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_different_k_neighbors_produces_different_hash(self, tmp_path): + """ + Test that different k_neighbors values produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=10, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_different_d_cdc_produces_different_hash(self, tmp_path): + """ + Test that different d_cdc values produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_different_gamma_produces_different_hash(self, tmp_path): + """ + Test that different gamma values produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=2.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_multiple_dataset_dirs_order_independent(self, tmp_path): + """ + Test that dataset directory order doesn't affect hash (they are sorted) + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", + dataset_dirs=[str(tmp_path / "dataset1"), str(tmp_path / "dataset2")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", + dataset_dirs=[str(tmp_path / "dataset2"), str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash == preprocessor2.config_hash + + def test_hash_length_is_8_chars(self, tmp_path): + """ + Test that hash is exactly 8 characters (hex) + """ + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert len(preprocessor.config_hash) == 8 + # Verify it's hex + int(preprocessor.config_hash, 16) # Should not raise + + def test_filename_includes_hash(self, tmp_path): + """ + Test that CDC filenames include the config hash + """ + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + latents_path = str(tmp_path / "image_0512x0768_flux.npz") + cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, preprocessor.config_hash) + + # Should be: image_0512x0768_flux_cdc_.npz + expected = str(tmp_path / f"image_0512x0768_flux_cdc_{preprocessor.config_hash}.npz") + assert cdc_path == expected + + def test_backward_compatibility_no_hash(self, tmp_path): + """ + Test that get_cdc_npz_path works without hash (backward compatibility) + """ + latents_path = str(tmp_path / "image_0512x0768_flux.npz") + cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None) + + # Should be: image_0512x0768_flux_cdc.npz (no hash suffix) + expected = str(tmp_path / "image_0512x0768_flux_cdc.npz") + assert cdc_path == expected + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_cdc_interpolation_comparison.py b/tests/library/test_cdc_interpolation_comparison.py deleted file mode 100644 index 46b2d8b2..00000000 --- a/tests/library/test_cdc_interpolation_comparison.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -Test comparing interpolation vs pad/truncate for CDC preprocessing. - -This test quantifies the difference between the two approaches. -""" - -import pytest -import torch -import torch.nn.functional as F - - -class TestInterpolationComparison: - """Compare interpolation vs pad/truncate""" - - def test_intermediate_representation_quality(self): - """Compare intermediate representation quality for CDC computation""" - # Create test latents with different sizes - deterministic - latent_small = torch.zeros(16, 4, 4) - for c in range(16): - for h in range(4): - for w in range(4): - latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0 - - latent_large = torch.zeros(16, 8, 8) - for c in range(16): - for h in range(8): - for w in range(8): - latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0 - - target_h, target_w = 6, 6 # Median size - - # Method 1: Interpolation - def interpolate_method(latent, target_h, target_w): - latent_input = latent.unsqueeze(0) # (1, C, H, W) - latent_resized = F.interpolate( - latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False - ) - # Resize back - C, H, W = latent.shape - latent_reconstructed = F.interpolate( - latent_resized, size=(H, W), mode='bilinear', align_corners=False - ) - error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item() - relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8) - return relative_error - - # Method 2: Pad/Truncate - def pad_truncate_method(latent, target_h, target_w): - C, H, W = latent.shape - latent_flat = latent.reshape(-1) - target_dim = C * target_h * target_w - current_dim = C * H * W - - if current_dim == target_dim: - latent_resized_flat = latent_flat - elif current_dim > target_dim: - # Truncate - latent_resized_flat = latent_flat[:target_dim] - else: - # Pad - latent_resized_flat = torch.zeros(target_dim) - latent_resized_flat[:current_dim] = latent_flat - - # Resize back - if current_dim == target_dim: - latent_reconstructed_flat = latent_resized_flat - elif current_dim > target_dim: - # Pad back - latent_reconstructed_flat = torch.zeros(current_dim) - latent_reconstructed_flat[:target_dim] = latent_resized_flat - else: - # Truncate back - latent_reconstructed_flat = latent_resized_flat[:current_dim] - - latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W) - error = torch.mean(torch.abs(latent_reconstructed - latent)).item() - relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8) - return relative_error - - # Compare for small latent (needs padding) - interp_error_small = interpolate_method(latent_small, target_h, target_w) - pad_error_small = pad_truncate_method(latent_small, target_h, target_w) - - # Compare for large latent (needs truncation) - interp_error_large = interpolate_method(latent_large, target_h, target_w) - truncate_error_large = pad_truncate_method(latent_large, target_h, target_w) - - print("\n" + "=" * 60) - print("Reconstruction Error Comparison") - print("=" * 60) - print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):") - print(f" Interpolation error: {interp_error_small:.6f}") - print(f" Pad/truncate error: {pad_error_small:.6f}") - if pad_error_small > 0: - print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%") - else: - print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)") - print(" BUT the intermediate representation is corrupted with zeros!") - - print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):") - print(f" Interpolation error: {interp_error_large:.6f}") - print(f" Pad/truncate error: {truncate_error_large:.6f}") - if truncate_error_large > 0: - print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%") - - # The key insight: Reconstruction error is NOT what matters for CDC! - # What matters is the INTERMEDIATE representation quality used for geometry estimation. - # Pad/truncate may have good reconstruction, but the intermediate is corrupted. - - print("\nKey insight: For CDC, intermediate representation quality matters,") - print("not reconstruction error. Interpolation preserves spatial structure.") - - # Verify interpolation errors are reasonable - assert interp_error_small < 1.0, "Interpolation should have reasonable error" - assert interp_error_large < 1.0, "Interpolation should have reasonable error" - - def test_spatial_structure_preservation(self): - """Test that interpolation preserves spatial structure better than pad/truncate""" - # Create a latent with clear spatial pattern (gradient) - C, H, W = 16, 4, 4 - latent = torch.zeros(C, H, W) - for i in range(H): - for j in range(W): - latent[:, i, j] = i * W + j # Gradient pattern - - target_h, target_w = 6, 6 - - # Interpolation - latent_input = latent.unsqueeze(0) - latent_interp = F.interpolate( - latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False - ).squeeze(0) - - # Pad/truncate - latent_flat = latent.reshape(-1) - target_dim = C * target_h * target_w - latent_padded = torch.zeros(target_dim) - latent_padded[:len(latent_flat)] = latent_flat - latent_pad = latent_padded.reshape(C, target_h, target_w) - - # Check gradient preservation - # For interpolation, adjacent pixels should have smooth gradients - grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean() - grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean() - - # For padding, there will be abrupt changes (gradient to zero) - grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean() - grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean() - - print("\n" + "=" * 60) - print("Spatial Structure Preservation") - print("=" * 60) - print("\nGradient smoothness (lower is smoother):") - print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}") - print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}") - - # Padding introduces larger gradients due to abrupt zeros - assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients" - assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients" - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_performance.py b/tests/library/test_cdc_performance.py deleted file mode 100644 index 1ebd0009..00000000 --- a/tests/library/test_cdc_performance.py +++ /dev/null @@ -1,412 +0,0 @@ -""" -Performance and Interpolation Tests for CDC Flow Matching - -This module provides testing of: -1. Computational overhead -2. Noise injection properties -3. Interpolation vs. pad/truncate methods -4. Spatial structure preservation -""" - -import pytest -import torch -import time -import tempfile -import numpy as np -import torch.nn.functional as F - -from library.cdc_fm import CDCPreprocessor, GammaBDataset - - -class TestCDCPerformanceAndInterpolation: - """ - Comprehensive performance testing for CDC Flow Matching - Covers computational efficiency, noise properties, and interpolation quality - """ - - @pytest.fixture(params=[ - (3, 32, 32), # Small latent: typical for compact representations - (3, 64, 64), # Medium latent: standard feature maps - (3, 128, 128) # Large latent: high-resolution feature spaces - ]) - def latent_sizes(self, request): - """ - Parametrized fixture generating test cases for different latent sizes. - - Rationale: - - Tests robustness across various computational scales - - Ensures consistent behavior from compact to large representations - - Identifies potential dimensionality-related performance bottlenecks - """ - return request.param - - def test_computational_overhead(self, latent_sizes): - """ - Measure computational overhead of CDC preprocessing across latent sizes. - - Performance Verification Objectives: - 1. Verify preprocessing time scales predictably with input dimensions - 2. Ensure adaptive k-neighbors works efficiently - 3. Validate computational overhead remains within acceptable bounds - - Performance Metrics: - - Total preprocessing time - - Per-sample processing time - - Computational complexity indicators - """ - # Tuned preprocessing configuration - preprocessor = CDCPreprocessor( - k_neighbors=256, # Comprehensive neighborhood exploration - d_cdc=8, # Geometric embedding dimensionality - debug=True, # Enable detailed performance logging - adaptive_k=True # Dynamic neighborhood size adjustment - ) - - # Set a fixed random seed for reproducibility - torch.manual_seed(42) # Consistent random generation - - # Generate representative latent batch - batch_size = 32 - latents = torch.randn(batch_size, *latent_sizes) - - # Precision timing of preprocessing - start_time = time.perf_counter() - - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - # Add latents with traceable metadata - for i, latent in enumerate(latents): - preprocessor.add_latent( - latent, - global_idx=i, - metadata={'image_key': f'perf_test_image_{i}'} - ) - - # Compute CDC results - cdc_path = preprocessor.compute_all(tmp_file.name) - - # Calculate precise preprocessing metrics - end_time = time.perf_counter() - preprocessing_time = end_time - start_time - per_sample_time = preprocessing_time / batch_size - - # Performance reporting and assertions - input_volume = np.prod(latent_sizes) - time_complexity_indicator = preprocessing_time / input_volume - - print(f"\nPerformance Breakdown:") - print(f" Latent Size: {latent_sizes}") - print(f" Total Samples: {batch_size}") - print(f" Input Volume: {input_volume}") - print(f" Total Time: {preprocessing_time:.4f} seconds") - print(f" Per Sample Time: {per_sample_time:.6f} seconds") - print(f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel") - - # Adaptive thresholds based on input dimensions - max_total_time = 10.0 # Base threshold - max_per_sample_time = 2.0 # Per-sample time threshold (more lenient) - - # Different time complexity thresholds for different latent sizes - max_time_complexity = ( - 1e-2 if np.prod(latent_sizes) <= 3072 else # Smaller latents - 1e-4 # Standard latents - ) - - # Performance assertions with informative error messages - assert preprocessing_time < max_total_time, ( - f"Total preprocessing time exceeded threshold!\n" - f" Latent Size: {latent_sizes}\n" - f" Total Time: {preprocessing_time:.4f} seconds\n" - f" Threshold: {max_total_time} seconds" - ) - - assert per_sample_time < max_per_sample_time, ( - f"Per-sample processing time exceeded threshold!\n" - f" Latent Size: {latent_sizes}\n" - f" Per Sample Time: {per_sample_time:.6f} seconds\n" - f" Threshold: {max_per_sample_time} seconds" - ) - - # More adaptable time complexity check - assert time_complexity_indicator < max_time_complexity, ( - f"Time complexity scaling exceeded expectations!\n" - f" Latent Size: {latent_sizes}\n" - f" Input Volume: {input_volume}\n" - f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n" - f" Threshold: {max_time_complexity} seconds/voxel" - ) - - def test_noise_distribution(self, latent_sizes): - """ - Verify CDC noise injection quality and properties. - - Based on test plan objectives: - 1. CDC noise is actually being generated (not all Gaussian fallback) - 2. Eigenvalues are valid (non-negative, bounded) - 3. CDC components are finite and usable for noise generation - """ - preprocessor = CDCPreprocessor( - k_neighbors=16, # Reduced to match batch size - d_cdc=8, - gamma=1.0, - debug=True, - adaptive_k=True - ) - - # Set a fixed random seed for reproducibility - torch.manual_seed(42) - - # Generate batch of latents - batch_size = 32 - latents = torch.randn(batch_size, *latent_sizes) - - with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file: - # Add latents with metadata - for i, latent in enumerate(latents): - preprocessor.add_latent( - latent, - global_idx=i, - metadata={'image_key': f'noise_dist_image_{i}'} - ) - - # Compute CDC results - cdc_path = preprocessor.compute_all(tmp_file.name) - - # Analyze noise properties - dataset = GammaBDataset(cdc_path) - - # Track samples that used CDC vs Gaussian fallback - cdc_samples = 0 - gaussian_samples = 0 - eigenvalue_stats = { - 'min': float('inf'), - 'max': float('-inf'), - 'mean': 0.0, - 'sum': 0.0 - } - - # Verify each sample's CDC components - for i in range(batch_size): - image_key = f'noise_dist_image_{i}' - - # Get eigenvectors and eigenvalues - eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key]) - - # Skip zero eigenvectors (fallback case) - if torch.all(eigvecs[0] == 0): - gaussian_samples += 1 - continue - - # Get the top d_cdc eigenvectors and eigenvalues - top_eigvecs = eigvecs[0] # (d_cdc, d) - top_eigvals = eigvals[0] # (d_cdc,) - - # Basic validity checks - assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}" - assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}" - - # Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM) - assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}" - assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}" - - # Update statistics - eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item()) - eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item()) - eigenvalue_stats['sum'] += top_eigvals.sum().item() - - cdc_samples += 1 - - # Compute mean eigenvalue across all CDC samples - if cdc_samples > 0: - eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * 8) # 8 = d_cdc - - # Print final statistics - print(f"\nNoise Distribution Results for latent size {latent_sizes}:") - print(f" CDC samples: {cdc_samples}/{batch_size}") - print(f" Gaussian fallback: {gaussian_samples}/{batch_size}") - print(f" Eigenvalue min: {eigenvalue_stats['min']:.4f}") - print(f" Eigenvalue max: {eigenvalue_stats['max']:.4f}") - print(f" Eigenvalue mean: {eigenvalue_stats['mean']:.4f}") - - # Assertions based on plan objectives - # 1. CDC noise should be generated for most samples - assert cdc_samples > 0, "No samples used CDC noise injection" - assert gaussian_samples < batch_size // 2, ( - f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}" - ) - - # 2. Eigenvalues should be valid (non-negative and bounded) - assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative" - assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0" - - # 3. Mean eigenvalue should be reasonable (not degenerate) - assert eigenvalue_stats['mean'] > 0.05, ( - f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), " - "suggests degenerate CDC components" - ) - - def test_interpolation_reconstruction(self): - """ - Compare interpolation vs pad/truncate reconstruction methods for CDC. - """ - # Create test latents with different sizes - deterministic - latent_small = torch.zeros(16, 4, 4) - for c in range(16): - for h in range(4): - for w in range(4): - latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0 - - latent_large = torch.zeros(16, 8, 8) - for c in range(16): - for h in range(8): - for w in range(8): - latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0 - - target_h, target_w = 6, 6 # Median size - - # Method 1: Interpolation - def interpolate_method(latent, target_h, target_w): - latent_input = latent.unsqueeze(0) # (1, C, H, W) - latent_resized = F.interpolate( - latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False - ) - # Resize back - C, H, W = latent.shape - latent_reconstructed = F.interpolate( - latent_resized, size=(H, W), mode='bilinear', align_corners=False - ) - error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item() - relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8) - return relative_error - - # Method 2: Pad/Truncate - def pad_truncate_method(latent, target_h, target_w): - C, H, W = latent.shape - latent_flat = latent.reshape(-1) - target_dim = C * target_h * target_w - current_dim = C * H * W - - if current_dim == target_dim: - latent_resized_flat = latent_flat - elif current_dim > target_dim: - # Truncate - latent_resized_flat = latent_flat[:target_dim] - else: - # Pad - latent_resized_flat = torch.zeros(target_dim) - latent_resized_flat[:current_dim] = latent_flat - - # Resize back - if current_dim == target_dim: - latent_reconstructed_flat = latent_resized_flat - elif current_dim > target_dim: - # Pad back - latent_reconstructed_flat = torch.zeros(current_dim) - latent_reconstructed_flat[:target_dim] = latent_resized_flat - else: - # Truncate back - latent_reconstructed_flat = latent_resized_flat[:current_dim] - - latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W) - error = torch.mean(torch.abs(latent_reconstructed - latent)).item() - relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8) - return relative_error - - # Compare for small latent (needs padding) - interp_error_small = interpolate_method(latent_small, target_h, target_w) - pad_error_small = pad_truncate_method(latent_small, target_h, target_w) - - # Compare for large latent (needs truncation) - interp_error_large = interpolate_method(latent_large, target_h, target_w) - truncate_error_large = pad_truncate_method(latent_large, target_h, target_w) - - print("\n" + "=" * 60) - print("Reconstruction Error Comparison") - print("=" * 60) - print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):") - print(f" Interpolation error: {interp_error_small:.6f}") - print(f" Pad/truncate error: {pad_error_small:.6f}") - if pad_error_small > 0: - print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%") - else: - print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)") - print(" BUT the intermediate representation is corrupted with zeros!") - - print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):") - print(f" Interpolation error: {interp_error_large:.6f}") - print(f" Pad/truncate error: {truncate_error_large:.6f}") - if truncate_error_large > 0: - print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%") - - print("\nKey insight: For CDC, intermediate representation quality matters,") - print("not reconstruction error. Interpolation preserves spatial structure.") - - # Verify interpolation errors are reasonable - assert interp_error_small < 1.0, "Interpolation should have reasonable error" - assert interp_error_large < 1.0, "Interpolation should have reasonable error" - - def test_spatial_structure_preservation(self): - """ - Test that interpolation preserves spatial structure better than pad/truncate. - """ - # Create a latent with clear spatial pattern (gradient) - C, H, W = 16, 4, 4 - latent = torch.zeros(C, H, W) - for i in range(H): - for j in range(W): - latent[:, i, j] = i * W + j # Gradient pattern - - target_h, target_w = 6, 6 - - # Interpolation - latent_input = latent.unsqueeze(0) - latent_interp = F.interpolate( - latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False - ).squeeze(0) - - # Pad/truncate - latent_flat = latent.reshape(-1) - target_dim = C * target_h * target_w - latent_padded = torch.zeros(target_dim) - latent_padded[:len(latent_flat)] = latent_flat - latent_pad = latent_padded.reshape(C, target_h, target_w) - - # Check gradient preservation - # For interpolation, adjacent pixels should have smooth gradients - grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean() - grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean() - - # For padding, there will be abrupt changes (gradient to zero) - grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean() - grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean() - - print("\n" + "=" * 60) - print("Spatial Structure Preservation") - print("=" * 60) - print("\nGradient smoothness (lower is smoother):") - print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}") - print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}") - - # Padding introduces larger gradients due to abrupt zeros - assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients" - assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients" - - -def pytest_configure(config): - """ - Configure performance benchmarking markers - """ - config.addinivalue_line( - "markers", - "performance: mark test to verify CDC-FM computational performance" - ) - config.addinivalue_line( - "markers", - "noise_distribution: mark test to verify noise injection properties" - ) - config.addinivalue_line( - "markers", - "interpolation: mark test to verify interpolation quality" - ) - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file diff --git a/tests/library/test_cdc_preprocessor.py b/tests/library/test_cdc_preprocessor.py index 63db6286..21005bab 100644 --- a/tests/library/test_cdc_preprocessor.py +++ b/tests/library/test_cdc_preprocessor.py @@ -29,7 +29,8 @@ class TestCDCPreprocessorIntegration: Test basic CDC preprocessing with small dataset """ preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash ) # Add 10 small latents @@ -51,8 +52,9 @@ class TestCDCPreprocessorIntegration: # Verify files were created assert files_saved == 10 - # Verify first CDC file structure - cdc_path = tmp_path / "test_image_0_0004x0004_flux_cdc.npz" + # Verify first CDC file structure (with config hash) + latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz") + cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash)) assert cdc_path.exists() import numpy as np @@ -73,7 +75,8 @@ class TestCDCPreprocessorIntegration: Test CDC preprocessing with variable-size latents (bucketing) """ preprocessor = CDCPreprocessor( - k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu" + k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash ) # Add 5 latents of shape (16, 4, 4) @@ -109,9 +112,15 @@ class TestCDCPreprocessorIntegration: assert files_saved == 10 import numpy as np - # Check shapes are stored in individual files - data_0 = np.load(tmp_path / "test_image_0_0004x0004_flux_cdc.npz") - data_5 = np.load(tmp_path / "test_image_5_0008x0008_flux_cdc.npz") + # Check shapes are stored in individual files (with config hash) + cdc_path_0 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash + ) + cdc_path_5 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash + ) + data_0 = np.load(cdc_path_0) + data_5 = np.load(cdc_path_5) assert tuple(data_0['shape']) == (16, 4, 4) assert tuple(data_5['shape']) == (16, 8, 8) @@ -128,7 +137,8 @@ class TestDeviceConsistency: """ # Create CDC cache on CPU preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" + k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash ) shape = (16, 32, 32) @@ -148,7 +158,7 @@ class TestDeviceConsistency: preprocessor.compute_all() - dataset = GammaBDataset(device="cpu") + dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") @@ -175,7 +185,8 @@ class TestDeviceConsistency: """ # Create CDC cache on CPU preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" + k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash ) shape = (16, 32, 32) @@ -195,7 +206,7 @@ class TestDeviceConsistency: preprocessor.compute_all() - dataset = GammaBDataset(device="cpu") + dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) # Create noise and timesteps noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True) @@ -236,7 +247,8 @@ class TestCDCEndToEnd: """ # Step 1: Preprocess latents preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash ) num_samples = 10 @@ -257,8 +269,8 @@ class TestCDCEndToEnd: files_saved = preprocessor.compute_all() assert files_saved == num_samples - # Step 2: Load with GammaBDataset - gamma_b_dataset = GammaBDataset(device="cpu") + # Step 2: Load with GammaBDataset (use config hash) + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) # Step 3: Use in mock training scenario batch_size = 3 diff --git a/tests/library/test_cdc_rescaling_recommendations.py b/tests/library/test_cdc_rescaling_recommendations.py deleted file mode 100644 index 75e8c3fb..00000000 --- a/tests/library/test_cdc_rescaling_recommendations.py +++ /dev/null @@ -1,237 +0,0 @@ -""" -Tests to validate the CDC rescaling recommendations from paper review. - -These tests check: -1. Gamma parameter interaction with rescaling -2. Spatial adaptivity of eigenvalue scaling -3. Verification of fixed vs adaptive rescaling behavior -""" - -import numpy as np -import pytest -import torch -from safetensors import safe_open - -from library.cdc_fm import CDCPreprocessor - - -class TestGammaRescalingInteraction: - """Test that gamma parameter works correctly with eigenvalue rescaling""" - - def test_gamma_scales_eigenvalues_correctly(self, tmp_path): - """Verify gamma multiplier is applied correctly after rescaling""" - # Create two preprocessors with different gamma values - gamma_values = [0.5, 1.0, 2.0] - eigenvalue_results = {} - - for gamma in gamma_values: - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu" - ) - - # Add identical deterministic data for all runs - for i in range(10): - latent = torch.zeros(16, 4, 4, dtype=torch.float32) - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.1 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / f"test_gamma_{gamma}.safetensors" - preprocessor.compute_all(save_path=output_path) - - # Extract eigenvalues - with safe_open(str(output_path), framework="pt", device="cpu") as f: - eigvals = f.get_tensor("eigenvalues/test_image_0").numpy() - eigenvalue_results[gamma] = eigvals - - # With clamping to [1e-3, gamma*1.0], verify gamma changes the upper bound - # Gamma 0.5: max eigenvalue should be ~0.5 - # Gamma 1.0: max eigenvalue should be ~1.0 - # Gamma 2.0: max eigenvalue should be ~2.0 - - max_0p5 = np.max(eigenvalue_results[0.5]) - max_1p0 = np.max(eigenvalue_results[1.0]) - max_2p0 = np.max(eigenvalue_results[2.0]) - - assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}" - assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}" - assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}" - - # All should have min of 1e-3 (clamp lower bound) - assert np.min(eigenvalue_results[0.5][eigenvalue_results[0.5] > 0]) >= 1e-3 - assert np.min(eigenvalue_results[1.0][eigenvalue_results[1.0] > 0]) >= 1e-3 - assert np.min(eigenvalue_results[2.0][eigenvalue_results[2.0] > 0]) >= 1e-3 - - print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}") - print(f"✓ Gamma 1.0 max: {max_1p0:.4f}") - print(f"✓ Gamma 2.0 max: {max_2p0:.4f}") - - def test_large_gamma_maintains_reasonable_scale(self, tmp_path): - """Verify that large gamma values don't cause eigenvalue explosion""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu" - ) - - for i in range(10): - latent = torch.zeros(16, 4, 4, dtype=torch.float32) - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h + w) / 20.0 + i * 0.15 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_large_gamma.safetensors" - preprocessor.compute_all(save_path=output_path) - - with safe_open(str(output_path), framework="pt", device="cpu") as f: - all_eigvals = [] - for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - all_eigvals.extend(eigvals) - - max_eigval = np.max(all_eigvals) - mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6]) - - # With gamma=10.0 and target_scale=0.1, eigenvalues should be ~1.0 - # But they should still be reasonable (not exploding) - assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma" - assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma" - - print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}") - - -class TestSpatialAdaptivityOfRescaling: - """Test spatial variation in eigenvalue scaling""" - - def test_eigenvalues_vary_spatially(self, tmp_path): - """Verify eigenvalues differ across spatially separated clusters""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" - ) - - # Create two distinct clusters in latent space - # Cluster 1: Tight cluster (low variance) - deterministic spread - for i in range(10): - latent = torch.zeros(16, 4, 4) - # Small variation around 0 - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h + w) / 100.0 + i * 0.01 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - # Cluster 2: Loose cluster (high variance) - deterministic spread - for i in range(10, 20): - latent = torch.ones(16, 4, 4) * 5.0 - # Large variation around 5.0 - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] += (c + h + w) / 10.0 + (i - 10) * 0.2 - metadata = {'image_key': f'test_image_{i}'} - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_spatial_variation.safetensors" - preprocessor.compute_all(save_path=output_path) - - with safe_open(str(output_path), framework="pt", device="cpu") as f: - # Get eigenvalues from both clusters - cluster1_eigvals = [] - cluster2_eigvals = [] - - for i in range(10): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - cluster1_eigvals.append(np.max(eigvals)) - - for i in range(10, 20): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - cluster2_eigvals.append(np.max(eigvals)) - - cluster1_mean = np.mean(cluster1_eigvals) - cluster2_mean = np.mean(cluster2_eigvals) - - print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}") - print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}") - - # With fixed target_scale rescaling, eigenvalues should be similar - # despite different local geometry - # This demonstrates the limitation of fixed rescaling - ratio = cluster2_mean / (cluster1_mean + 1e-10) - print(f"✓ Ratio (loose/tight): {ratio:.2f}") - - # Both should be rescaled to similar magnitude (~0.1 due to target_scale) - assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range" - assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range" - - -class TestFixedVsAdaptiveRescaling: - """Compare current fixed rescaling vs paper's adaptive approach""" - - def test_current_rescaling_is_uniform(self, tmp_path): - """Demonstrate that current rescaling produces uniform eigenvalue scales""" - preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" - ) - - # Create samples with varying local density - deterministic - for i in range(20): - latent = torch.zeros(16, 4, 4) - # Some samples clustered, some isolated - if i < 10: - # Dense cluster around origin - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h + w) / 40.0 + i * 0.05 - else: - # Isolated points - larger offset - for c in range(16): - for h in range(4): - for w in range(4): - latent[c, h, w] = (c + h + w) / 40.0 + i * 2.0 - - metadata = {'image_key': f'test_image_{i}'} - - - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) - - output_path = tmp_path / "test_uniform_rescaling.safetensors" - preprocessor.compute_all(save_path=output_path) - - with safe_open(str(output_path), framework="pt", device="cpu") as f: - max_eigenvalues = [] - for i in range(20): - eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy() - vals = eigvals[eigvals > 1e-6] - if vals.size: # at least one valid eigen-value - max_eigenvalues.append(vals.max()) - - if not max_eigenvalues: # safeguard against empty list - pytest.skip("no valid eigen-values found") - - max_eigenvalues = np.array(max_eigenvalues) - - # Check coefficient of variation (std / mean) - cv = max_eigenvalues.std() / max_eigenvalues.mean() - - print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f}, {np.max(max_eigenvalues):.4f}]") - print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}") - print(f"✓ Coefficient of variation: {cv:.4f}") - - # With clamping, eigenvalues should have relatively low variation - assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping" - # Mean should be reasonable (clamped to [1e-3, gamma*1.0] = [1e-3, 1.0]) - assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range" - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_standalone.py b/tests/library/test_cdc_standalone.py index c7fb2d85..6815b4da 100644 --- a/tests/library/test_cdc_standalone.py +++ b/tests/library/test_cdc_standalone.py @@ -1,132 +1,176 @@ """ -Standalone tests for CDC-FM integration. +Standalone tests for CDC-FM per-file caching. -These tests focus on CDC-FM specific functionality without importing -the full training infrastructure that has problematic dependencies. +These tests focus on the current CDC-FM per-file caching implementation +with hash-based cache validation. """ from pathlib import Path import pytest import torch -from safetensors.torch import save_file +import numpy as np from library.cdc_fm import CDCPreprocessor, GammaBDataset class TestCDCPreprocessor: - """Test CDC preprocessing functionality""" + """Test CDC preprocessing functionality with per-file caching""" def test_cdc_preprocessor_basic_workflow(self, tmp_path): """Test basic CDC preprocessing with small dataset""" preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] ) # Add 10 small latents for i in range(10): latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) - # Compute and save - output_path = tmp_path / "test_gamma_b.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) + # Compute and save (creates per-file CDC caches) + files_saved = preprocessor.compute_all() - # Verify file was created - assert Path(result_path).exists() + # Verify files were created + assert files_saved == 10 - # Verify structure - from safetensors import safe_open + # Verify first CDC file structure + latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz") + cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash)) + assert cdc_path.exists() - with safe_open(str(result_path), framework="pt", device="cpu") as f: - assert f.get_tensor("metadata/num_samples").item() == 10 - assert f.get_tensor("metadata/k_neighbors").item() == 5 - assert f.get_tensor("metadata/d_cdc").item() == 4 + data = np.load(cdc_path) + assert data['k_neighbors'] == 5 + assert data['d_cdc'] == 4 - # Check first sample - eigvecs = f.get_tensor("eigenvectors/test_image_0") - eigvals = f.get_tensor("eigenvalues/test_image_0") + # Check eigenvectors and eigenvalues + eigvecs = data['eigenvectors'] + eigvals = data['eigenvalues'] - assert eigvecs.shape[0] == 4 # d_cdc - assert eigvals.shape[0] == 4 # d_cdc + assert eigvecs.shape[0] == 4 # d_cdc + assert eigvals.shape[0] == 4 # d_cdc def test_cdc_preprocessor_different_shapes(self, tmp_path): """Test CDC preprocessing with variable-size latents (bucketing)""" preprocessor = CDCPreprocessor( - k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu" + k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] ) # Add 5 latents of shape (16, 4, 4) for i in range(5): latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) # Add 5 latents of different shape (16, 8, 8) for i in range(5, 10): latent = torch.randn(16, 8, 8, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz") metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) # Compute and save - output_path = tmp_path / "test_gamma_b_multi.safetensors" - result_path = preprocessor.compute_all(save_path=output_path) + files_saved = preprocessor.compute_all() # Verify both shape groups were processed - from safetensors import safe_open + assert files_saved == 10 - with safe_open(str(result_path), framework="pt", device="cpu") as f: - # Check shapes are stored - shape_0 = f.get_tensor("shapes/test_image_0") - shape_5 = f.get_tensor("shapes/test_image_5") + # Check shapes are stored in individual files + cdc_path_0 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash + ) + cdc_path_5 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash + ) - assert tuple(shape_0.tolist()) == (16, 4, 4) - assert tuple(shape_5.tolist()) == (16, 8, 8) + data_0 = np.load(cdc_path_0) + data_5 = np.load(cdc_path_5) + + assert tuple(data_0['shape']) == (16, 4, 4) + assert tuple(data_5['shape']) == (16, 8, 8) class TestGammaBDataset: - """Test GammaBDataset loading and retrieval""" + """Test GammaBDataset loading and retrieval with per-file caching""" @pytest.fixture def sample_cdc_cache(self, tmp_path): - """Create a sample CDC cache file for testing""" - cache_path = tmp_path / "test_gamma_b.safetensors" + """Create sample CDC cache files for testing""" + # Use 20 samples to ensure proper k-NN computation + # (minimum 256 neighbors recommended, but 20 samples with k=5 is sufficient for testing) + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)], + adaptive_k=True, # Enable adaptive k for small dataset + min_bucket_size=5 + ) - # Create mock Γ_b data for 5 samples - tensors = { - "metadata/num_samples": torch.tensor([5]), - "metadata/k_neighbors": torch.tensor([10]), - "metadata/d_cdc": torch.tensor([4]), - "metadata/gamma": torch.tensor([1.0]), - } + # Create 20 samples + latents_npz_paths = [] + for i in range(20): + latent = torch.randn(16, 8, 8, dtype=torch.float32) # C=16, d=1024 when flattened + latents_npz_path = str(tmp_path / f"test_{i}_0008x0008_flux.npz") + latents_npz_paths.append(latents_npz_path) + metadata = {'image_key': f'test_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) - # Add shape and CDC data for each sample - for i in range(5): - tensors[f"shapes/{i}"] = torch.tensor([16, 8, 8]) # C, H, W - tensors[f"eigenvectors/{i}"] = torch.randn(4, 1024, dtype=torch.float32) # d_cdc x d - tensors[f"eigenvalues/{i}"] = torch.rand(4, dtype=torch.float32) + 0.1 # positive - - save_file(tensors, str(cache_path)) - return cache_path + preprocessor.compute_all() + return tmp_path, latents_npz_paths, preprocessor.config_hash def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache): - """Test that GammaBDataset loads metadata correctly""" - gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + """Test that GammaBDataset loads CDC files correctly""" + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) - assert gamma_b_dataset.num_samples == 5 - assert gamma_b_dataset.d_cdc == 4 + # Get components for first sample + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu") + + # Check shapes + assert eigvecs.shape[0] == 1 # batch size + assert eigvecs.shape[1] == 4 # d_cdc + assert eigvals.shape == (1, 4) # batch, d_cdc def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache): """Test retrieving Γ_b^(1/2) components""" - gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) - # Get Γ_b for indices [0, 2, 4] - indices = [0, 2, 4] - eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(indices, device="cpu") + # Get Γ_b for paths [0, 2, 4] + paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]] + eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") # Check shapes - assert eigenvectors.shape == (3, 4, 1024) # (batch, d_cdc, d) + assert eigenvectors.shape[0] == 3 # batch + assert eigenvectors.shape[1] == 4 # d_cdc assert eigenvalues.shape == (3, 4) # (batch, d_cdc) # Check values are positive @@ -134,14 +178,16 @@ class TestGammaBDataset: def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache): """Test compute_sigma_t_x returns x unchanged at t=0""" - gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) # Create test latents (batch of 3, matching d=1024 flattened) x = torch.randn(3, 1024) # B, d (flattened) t = torch.zeros(3) # t = 0 for all samples # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 1, 2], device="cpu") + paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]] + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) @@ -150,13 +196,15 @@ class TestGammaBDataset: def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache): """Test compute_sigma_t_x returns correct shape""" - gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) x = torch.randn(2, 1024) # B, d (flattened) t = torch.tensor([0.3, 0.7]) # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([1, 3], device="cpu") + paths = [latents_npz_paths[1], latents_npz_paths[3]] + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) @@ -165,13 +213,15 @@ class TestGammaBDataset: def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache): """Test compute_sigma_t_x produces finite values""" - gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) x = torch.randn(3, 1024) # B, d (flattened) t = torch.rand(3) # Random timesteps in [0, 1] # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 2, 4], device="cpu") + paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]] + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) @@ -187,31 +237,39 @@ class TestCDCEndToEnd: """Test complete workflow: preprocess -> save -> load -> use""" # Step 1: Preprocess latents preprocessor = CDCPreprocessor( - k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] ) num_samples = 10 + latents_npz_paths = [] for i in range(num_samples): latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + latents_npz_paths.append(latents_npz_path) metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata) + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) - output_path = tmp_path / "cdc_gamma_b.safetensors" - cdc_path = preprocessor.compute_all(save_path=output_path) + files_saved = preprocessor.compute_all() + assert files_saved == num_samples # Step 2: Load with GammaBDataset - gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu") - - assert gamma_b_dataset.num_samples == num_samples + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) # Step 3: Use in mock training scenario batch_size = 3 batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256) batch_t = torch.rand(batch_size) - image_keys = ['test_image_0', 'test_image_5', 'test_image_9'] + paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]] # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu") + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu") # Compute geometry-aware noise sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) diff --git a/tests/library/test_cdc_warning_throttling.py b/tests/library/test_cdc_warning_throttling.py deleted file mode 100644 index d8cba614..00000000 --- a/tests/library/test_cdc_warning_throttling.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -Test warning throttling for CDC shape mismatches. - -Ensures that duplicate warnings for the same sample are not logged repeatedly. -""" - -import pytest -import torch -import logging - -from library.cdc_fm import CDCPreprocessor, GammaBDataset -from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples - - -class TestWarningThrottling: - """Test that shape mismatch warnings are throttled""" - - @pytest.fixture(autouse=True) - def clear_warned_samples(self): - """Clear the warned samples set before each test""" - _cdc_warned_samples.clear() - yield - _cdc_warned_samples.clear() - - @pytest.fixture - def cdc_cache(self, tmp_path): - """Create a test CDC cache with one shape""" - preprocessor = CDCPreprocessor( - k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" - ) - - # Create cache with one specific shape - preprocessed_shape = (16, 32, 32) - for i in range(10): - latent = torch.randn(*preprocessed_shape, dtype=torch.float32) - metadata = {'image_key': f'test_image_{i}'} - preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata) - - cache_path = tmp_path / "test_throttle.safetensors" - preprocessor.compute_all(save_path=cache_path) - return cache_path - - def test_warning_only_logged_once_per_sample(self, cdc_cache, caplog): - """ - Test that shape mismatch warning is only logged once per sample. - - Even if the same sample appears in multiple batches, only warn once. - """ - dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu") - - # Use different shape at runtime to trigger mismatch - runtime_shape = (16, 64, 64) - timesteps = torch.tensor([100.0], dtype=torch.float32) - image_keys = ['test_image_0'] # Same sample - - # First call - should warn - with caplog.at_level(logging.WARNING): - caplog.clear() - noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32) - _ = apply_cdc_noise_transformation( - noise=noise1, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have exactly one warning - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 1, "First call should produce exactly one warning" - assert "CDC shape mismatch" in warnings[0].message - - # Second call with same sample - should NOT warn - with caplog.at_level(logging.WARNING): - caplog.clear() - noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32) - _ = apply_cdc_noise_transformation( - noise=noise2, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have NO warnings - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 0, "Second call with same sample should not warn" - - # Third call with same sample - still should NOT warn - with caplog.at_level(logging.WARNING): - caplog.clear() - noise3 = torch.randn(1, *runtime_shape, dtype=torch.float32) - _ = apply_cdc_noise_transformation( - noise=noise3, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 0, "Third call should still not warn" - - def test_different_samples_each_get_one_warning(self, cdc_cache, caplog): - """ - Test that different samples each get their own warning. - - Each unique sample should be warned about once. - """ - dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu") - - runtime_shape = (16, 64, 64) - timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32) - - # First batch: samples 0, 1, 2 - with caplog.at_level(logging.WARNING): - caplog.clear() - noise = torch.randn(3, *runtime_shape, dtype=torch.float32) - image_keys = ['test_image_0', 'test_image_1', 'test_image_2'] - - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have 3 warnings (one per sample) - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 3, "Should warn for each of the 3 samples" - - # Second batch: same samples 0, 1, 2 - with caplog.at_level(logging.WARNING): - caplog.clear() - noise = torch.randn(3, *runtime_shape, dtype=torch.float32) - image_keys = ['test_image_0', 'test_image_1', 'test_image_2'] - - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have NO warnings (already warned) - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 0, "Should not warn again for same samples" - - # Third batch: new samples 3, 4 - with caplog.at_level(logging.WARNING): - caplog.clear() - noise = torch.randn(2, *runtime_shape, dtype=torch.float32) - image_keys = ['test_image_3', 'test_image_4'] - timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32) - - _ = apply_cdc_noise_transformation( - noise=noise, - timesteps=timesteps, - num_timesteps=1000, - gamma_b_dataset=dataset, - image_keys=image_keys, - device="cpu" - ) - - # Should have 2 warnings (new samples) - warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"] - assert len(warnings) == 2, "Should warn for each of the 2 new samples" - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"])