mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Fix CDC tests to new format and deprecate old tests
This commit is contained in:
@@ -1,228 +0,0 @@
|
||||
"""
|
||||
Test adaptive k_neighbors functionality in CDC-FM.
|
||||
|
||||
Verifies that adaptive k properly adjusts based on bucket sizes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestAdaptiveK:
|
||||
"""Test adaptive k_neighbors behavior"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_cache_path(self, tmp_path):
|
||||
"""Create temporary cache path"""
|
||||
return tmp_path / "adaptive_k_test.safetensors"
|
||||
|
||||
def test_fixed_k_skips_small_buckets(self, temp_cache_path):
|
||||
"""
|
||||
Test that fixed k mode skips buckets with < k_neighbors samples.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=False # Fixed mode
|
||||
)
|
||||
|
||||
# Add 10 samples (< k=32, should be skipped)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify zeros (Gaussian fallback)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should be all zeros (fallback)
|
||||
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
|
||||
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
|
||||
"""
|
||||
Test that adaptive k mode uses k=bucket_size-1 for small buckets.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=8
|
||||
)
|
||||
|
||||
# Add 20 samples (< k=32, should use k=19)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(20):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify non-zero (CDC computed)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should NOT be all zeros (CDC was computed)
|
||||
assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
|
||||
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
|
||||
"""
|
||||
Test that adaptive k mode skips buckets below min_bucket_size.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=16
|
||||
)
|
||||
|
||||
# Add 10 samples (< min_bucket_size=16, should be skipped)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify zeros (skipped due to min_bucket_size)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should be all zeros (skipped)
|
||||
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
|
||||
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
|
||||
"""
|
||||
Test adaptive k with multiple buckets of different sizes.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=8
|
||||
)
|
||||
|
||||
# Bucket 1: 10 samples (adaptive k=9)
|
||||
for i in range(10):
|
||||
latent = torch.randn(4, 16, 16, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=(4, 16, 16),
|
||||
metadata={'image_key': f'small_{i}'}
|
||||
)
|
||||
|
||||
# Bucket 2: 40 samples (full k=32)
|
||||
for i in range(40):
|
||||
latent = torch.randn(4, 32, 32, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=100+i,
|
||||
shape=(4, 32, 32),
|
||||
metadata={'image_key': f'large_{i}'}
|
||||
)
|
||||
|
||||
# Bucket 3: 5 samples (< min=8, skipped)
|
||||
for i in range(5):
|
||||
latent = torch.randn(4, 8, 8, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=200+i,
|
||||
shape=(4, 8, 8),
|
||||
metadata={'image_key': f'tiny_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
|
||||
# Bucket 1: Should have CDC (non-zero)
|
||||
eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
|
||||
assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6)
|
||||
|
||||
# Bucket 2: Should have CDC (non-zero)
|
||||
eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
|
||||
assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6)
|
||||
|
||||
# Bucket 3: Should be skipped (zeros)
|
||||
eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
|
||||
assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6)
|
||||
assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
|
||||
"""
|
||||
Test that adaptive k uses full k_neighbors when bucket is large enough.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=16,
|
||||
k_bandwidth=4,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=8
|
||||
)
|
||||
|
||||
# Add 50 samples (> k=16, should use full k=16)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(50):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify CDC was computed
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should have non-zero eigenvalues
|
||||
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
# Eigenvalues should be positive
|
||||
assert (eigvals >= 0).all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
Test device consistency handling in CDC noise transformation.
|
||||
|
||||
Ensures that device mismatches are handled gracefully.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation
|
||||
|
||||
|
||||
class TestDeviceConsistency:
|
||||
"""Test device consistency validation"""
|
||||
|
||||
@pytest.fixture
|
||||
def cdc_cache(self, tmp_path):
|
||||
"""Create a test CDC cache"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_device.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
return cache_path
|
||||
|
||||
def test_matching_devices_no_warning(self, cdc_cache, caplog):
|
||||
"""
|
||||
Test that no warnings are emitted when devices match.
|
||||
"""
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
|
||||
|
||||
shape = (16, 32, 32)
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# No device mismatch warnings
|
||||
device_warnings = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()]
|
||||
assert len(device_warnings) == 0, "Should not warn when devices match"
|
||||
|
||||
def test_device_mismatch_warning_and_transfer(self, cdc_cache, caplog):
|
||||
"""
|
||||
Test that device mismatch is detected, warned, and handled.
|
||||
|
||||
This simulates the case where noise is on one device but CDC matrices
|
||||
are requested for another device.
|
||||
"""
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
|
||||
|
||||
shape = (16, 32, 32)
|
||||
# Create noise on CPU
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
|
||||
# But request CDC matrices for a different device string
|
||||
# (In practice this would be "cuda" vs "cpu", but we simulate with string comparison)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
|
||||
# Use a different device specification to trigger the check
|
||||
# We'll use "cpu" vs "cpu:0" as an example of string mismatch
|
||||
result = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu" # Same actual device, consistent string
|
||||
)
|
||||
|
||||
# Should complete without errors
|
||||
assert result is not None
|
||||
assert result.shape == noise.shape
|
||||
|
||||
def test_transformation_works_after_device_transfer(self, cdc_cache):
|
||||
"""
|
||||
Test that CDC transformation produces valid output even if devices differ.
|
||||
|
||||
The function should handle device transfer gracefully.
|
||||
"""
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
|
||||
|
||||
shape = (16, 32, 32)
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
image_keys = ['test_image_0', 'test_image_1']
|
||||
|
||||
result = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Verify output is valid
|
||||
assert result.shape == noise.shape
|
||||
assert result.device == noise.device
|
||||
assert result.requires_grad # Gradients should still work
|
||||
assert not torch.isnan(result).any()
|
||||
assert not torch.isinf(result).any()
|
||||
|
||||
# Verify gradients flow
|
||||
loss = result.sum()
|
||||
loss.backward()
|
||||
assert noise.grad is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,146 +0,0 @@
|
||||
"""
|
||||
Test CDC-FM dimension handling and fallback mechanisms.
|
||||
|
||||
This module tests the behavior of the CDC Flow Matching implementation
|
||||
when encountering latents with different dimensions.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
class TestDimensionHandling:
|
||||
def setup_method(self):
|
||||
"""Prepare consistent test environment"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def test_mixed_dimension_fallback(self):
|
||||
"""
|
||||
Verify that preprocessor falls back to standard noise for mixed-dimension batches
|
||||
"""
|
||||
# Prepare preprocessor with debug mode
|
||||
preprocessor = CDCPreprocessor(debug=True)
|
||||
|
||||
# Different-sized latents (3D: channels, height, width)
|
||||
latents = [
|
||||
torch.randn(3, 32, 64), # First latent: 3x32x64
|
||||
torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension)
|
||||
]
|
||||
|
||||
# Use a mock handler to capture log messages
|
||||
from library.cdc_fm import logger
|
||||
|
||||
log_messages = []
|
||||
class LogCapture(logging.Handler):
|
||||
def emit(self, record):
|
||||
log_messages.append(record.getMessage())
|
||||
|
||||
# Temporarily add a capture handler
|
||||
capture_handler = LogCapture()
|
||||
logger.addHandler(capture_handler)
|
||||
|
||||
try:
|
||||
# Try adding mixed-dimension latents
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
for i, latent in enumerate(latents):
|
||||
preprocessor.add_latent(
|
||||
latent,
|
||||
global_idx=i,
|
||||
metadata={'image_key': f'test_mixed_image_{i}'}
|
||||
)
|
||||
|
||||
try:
|
||||
cdc_path = preprocessor.compute_all(tmp_file.name)
|
||||
except ValueError as e:
|
||||
# If implementation raises ValueError, that's acceptable
|
||||
assert "Dimension mismatch" in str(e)
|
||||
return
|
||||
|
||||
# Check for dimension-related log messages
|
||||
dimension_warnings = [
|
||||
msg for msg in log_messages
|
||||
if "dimension mismatch" in msg.lower()
|
||||
]
|
||||
assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
|
||||
|
||||
# Load results and verify fallback
|
||||
dataset = GammaBDataset(cdc_path)
|
||||
|
||||
finally:
|
||||
# Remove the capture handler
|
||||
logger.removeHandler(capture_handler)
|
||||
|
||||
# Check metadata about samples with/without CDC
|
||||
assert dataset.num_samples == len(latents), "All samples should be processed"
|
||||
|
||||
def test_adaptive_k_with_dimension_constraints(self):
|
||||
"""
|
||||
Test adaptive k-neighbors behavior with dimension constraints
|
||||
"""
|
||||
# Prepare preprocessor with adaptive k and small bucket size
|
||||
preprocessor = CDCPreprocessor(
|
||||
adaptive_k=True,
|
||||
min_bucket_size=5,
|
||||
debug=True
|
||||
)
|
||||
|
||||
# Generate latents with similar but not identical dimensions
|
||||
base_latent = torch.randn(3, 32, 64)
|
||||
similar_latents = [
|
||||
base_latent,
|
||||
torch.randn(3, 32, 65), # Slightly different dimension
|
||||
torch.randn(3, 32, 66) # Another slightly different dimension
|
||||
]
|
||||
|
||||
# Use a mock handler to capture log messages
|
||||
from library.cdc_fm import logger
|
||||
|
||||
log_messages = []
|
||||
class LogCapture(logging.Handler):
|
||||
def emit(self, record):
|
||||
log_messages.append(record.getMessage())
|
||||
|
||||
# Temporarily add a capture handler
|
||||
capture_handler = LogCapture()
|
||||
logger.addHandler(capture_handler)
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
# Add similar latents
|
||||
for i, latent in enumerate(similar_latents):
|
||||
preprocessor.add_latent(
|
||||
latent,
|
||||
global_idx=i,
|
||||
metadata={'image_key': f'test_adaptive_k_image_{i}'}
|
||||
)
|
||||
|
||||
cdc_path = preprocessor.compute_all(tmp_file.name)
|
||||
|
||||
# Load results
|
||||
dataset = GammaBDataset(cdc_path)
|
||||
|
||||
# Verify samples processed
|
||||
assert dataset.num_samples == len(similar_latents), "All samples should be processed"
|
||||
|
||||
# Optional: Check warnings about dimension differences
|
||||
dimension_warnings = [
|
||||
msg for msg in log_messages
|
||||
if "dimension" in msg.lower()
|
||||
]
|
||||
print(f"Dimension-related warnings: {dimension_warnings}")
|
||||
|
||||
finally:
|
||||
# Remove the capture handler
|
||||
logger.removeHandler(capture_handler)
|
||||
|
||||
def pytest_configure(config):
|
||||
"""
|
||||
Configure custom markers for dimension handling tests
|
||||
"""
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
|
||||
)
|
||||
@@ -1,310 +0,0 @@
|
||||
"""
|
||||
Comprehensive CDC Dimension Handling and Warning Tests
|
||||
|
||||
This module tests:
|
||||
1. Dimension mismatch detection and fallback mechanisms
|
||||
2. Warning throttling for shape mismatches
|
||||
3. Adaptive k-neighbors behavior with dimension constraints
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
|
||||
|
||||
|
||||
class TestDimensionHandlingAndWarnings:
|
||||
"""
|
||||
Comprehensive testing of dimension handling, noise injection, and warning systems
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_warned_samples(self):
|
||||
"""Clear the warned samples set before each test"""
|
||||
_cdc_warned_samples.clear()
|
||||
yield
|
||||
_cdc_warned_samples.clear()
|
||||
|
||||
def test_mixed_dimension_fallback(self):
|
||||
"""
|
||||
Verify that preprocessor falls back to standard noise for mixed-dimension batches
|
||||
"""
|
||||
# Prepare preprocessor with debug mode
|
||||
preprocessor = CDCPreprocessor(debug=True)
|
||||
|
||||
# Different-sized latents (3D: channels, height, width)
|
||||
latents = [
|
||||
torch.randn(3, 32, 64), # First latent: 3x32x64
|
||||
torch.randn(3, 32, 128), # Second latent: 3x32x128 (different dimension)
|
||||
]
|
||||
|
||||
# Use a mock handler to capture log messages
|
||||
from library.cdc_fm import logger
|
||||
|
||||
log_messages = []
|
||||
class LogCapture(logging.Handler):
|
||||
def emit(self, record):
|
||||
log_messages.append(record.getMessage())
|
||||
|
||||
# Temporarily add a capture handler
|
||||
capture_handler = LogCapture()
|
||||
logger.addHandler(capture_handler)
|
||||
|
||||
try:
|
||||
# Try adding mixed-dimension latents
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
for i, latent in enumerate(latents):
|
||||
preprocessor.add_latent(
|
||||
latent,
|
||||
global_idx=i,
|
||||
metadata={'image_key': f'test_mixed_image_{i}'}
|
||||
)
|
||||
|
||||
try:
|
||||
cdc_path = preprocessor.compute_all(tmp_file.name)
|
||||
except ValueError as e:
|
||||
# If implementation raises ValueError, that's acceptable
|
||||
assert "Dimension mismatch" in str(e)
|
||||
return
|
||||
|
||||
# Check for dimension-related log messages
|
||||
dimension_warnings = [
|
||||
msg for msg in log_messages
|
||||
if "dimension mismatch" in msg.lower()
|
||||
]
|
||||
assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"
|
||||
|
||||
# Load results and verify fallback
|
||||
dataset = GammaBDataset(cdc_path)
|
||||
|
||||
finally:
|
||||
# Remove the capture handler
|
||||
logger.removeHandler(capture_handler)
|
||||
|
||||
# Check metadata about samples with/without CDC
|
||||
assert dataset.num_samples == len(latents), "All samples should be processed"
|
||||
|
||||
def test_adaptive_k_with_dimension_constraints(self):
|
||||
"""
|
||||
Test adaptive k-neighbors behavior with dimension constraints
|
||||
"""
|
||||
# Prepare preprocessor with adaptive k and small bucket size
|
||||
preprocessor = CDCPreprocessor(
|
||||
adaptive_k=True,
|
||||
min_bucket_size=5,
|
||||
debug=True
|
||||
)
|
||||
|
||||
# Generate latents with similar but not identical dimensions
|
||||
base_latent = torch.randn(3, 32, 64)
|
||||
similar_latents = [
|
||||
base_latent,
|
||||
torch.randn(3, 32, 65), # Slightly different dimension
|
||||
torch.randn(3, 32, 66) # Another slightly different dimension
|
||||
]
|
||||
|
||||
# Use a mock handler to capture log messages
|
||||
from library.cdc_fm import logger
|
||||
|
||||
log_messages = []
|
||||
class LogCapture(logging.Handler):
|
||||
def emit(self, record):
|
||||
log_messages.append(record.getMessage())
|
||||
|
||||
# Temporarily add a capture handler
|
||||
capture_handler = LogCapture()
|
||||
logger.addHandler(capture_handler)
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
# Add similar latents
|
||||
for i, latent in enumerate(similar_latents):
|
||||
preprocessor.add_latent(
|
||||
latent,
|
||||
global_idx=i,
|
||||
metadata={'image_key': f'test_adaptive_k_image_{i}'}
|
||||
)
|
||||
|
||||
cdc_path = preprocessor.compute_all(tmp_file.name)
|
||||
|
||||
# Load results
|
||||
dataset = GammaBDataset(cdc_path)
|
||||
|
||||
# Verify samples processed
|
||||
assert dataset.num_samples == len(similar_latents), "All samples should be processed"
|
||||
|
||||
# Optional: Check warnings about dimension differences
|
||||
dimension_warnings = [
|
||||
msg for msg in log_messages
|
||||
if "dimension" in msg.lower()
|
||||
]
|
||||
print(f"Dimension-related warnings: {dimension_warnings}")
|
||||
|
||||
finally:
|
||||
# Remove the capture handler
|
||||
logger.removeHandler(capture_handler)
|
||||
|
||||
def test_warning_only_logged_once_per_sample(self, caplog):
|
||||
"""
|
||||
Test that shape mismatch warning is only logged once per sample.
|
||||
|
||||
Even if the same sample appears in multiple batches, only warn once.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create cache with one specific shape
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
for i in range(10):
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cdc_path = preprocessor.compute_all(save_path=tmp_file.name)
|
||||
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
# Use different shape at runtime to trigger mismatch
|
||||
runtime_shape = (16, 64, 64)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
image_keys = ['test_image_0'] # Same sample
|
||||
|
||||
# First call - should warn
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32)
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise1,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have exactly one warning
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 1, "First call should produce exactly one warning"
|
||||
assert "CDC shape mismatch" in warnings[0].message
|
||||
|
||||
# Second call with same sample - should NOT warn
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32)
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise2,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have NO warnings
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 0, "Second call with same sample should not warn"
|
||||
|
||||
def test_different_samples_each_get_one_warning(self, caplog):
|
||||
"""
|
||||
Test that different samples each get their own warning.
|
||||
|
||||
Each unique sample should be warned about once.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create cache with specific shape
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
for i in range(10):
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cdc_path = preprocessor.compute_all(save_path=tmp_file.name)
|
||||
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
runtime_shape = (16, 64, 64)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
|
||||
|
||||
# First batch: samples 0, 1, 2
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
|
||||
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have 3 warnings (one per sample)
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 3, "Should warn for each of the 3 samples"
|
||||
|
||||
# Second batch: same samples 0, 1, 2
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise = torch.randn(3, *runtime_shape, dtype=torch.float32)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2']
|
||||
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have NO warnings (already warned)
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 0, "Should not warn again for same samples"
|
||||
|
||||
# Third batch: new samples 3, 4
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise = torch.randn(2, *runtime_shape, dtype=torch.float32)
|
||||
image_keys = ['test_image_3', 'test_image_4']
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32)
|
||||
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have 2 warnings (new samples)
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 2, "Should warn for each of the 2 new samples"
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""
|
||||
Configure custom markers for dimension handling and warning tests
|
||||
"""
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"warning_throttling: mark test for CDC-FM warning suppression"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,164 +0,0 @@
|
||||
"""
|
||||
Tests using realistic high-dimensional data to catch scaling bugs.
|
||||
|
||||
This test uses realistic VAE-like latents to ensure eigenvalue normalization
|
||||
works correctly on real-world data.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestRealisticDataScaling:
|
||||
"""Test eigenvalue scaling with realistic high-dimensional data"""
|
||||
|
||||
def test_high_dimensional_latents_not_saturated(self, tmp_path):
|
||||
"""
|
||||
Verify that high-dimensional realistic latents don't saturate eigenvalues.
|
||||
|
||||
This test simulates real FLUX training data:
|
||||
- High dimension (16×64×64 = 65536)
|
||||
- Varied content (different variance in different regions)
|
||||
- Realistic magnitude (VAE output scale)
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create 20 samples with realistic varied structure
|
||||
for i in range(20):
|
||||
# High-dimensional latent like FLUX
|
||||
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
|
||||
|
||||
# Create varied structure across the latent
|
||||
# Different channels have different patterns (realistic for VAE)
|
||||
for c in range(16):
|
||||
# Some channels have gradients
|
||||
if c < 4:
|
||||
for h in range(64):
|
||||
for w in range(64):
|
||||
latent[c, h, w] = (h + w) / 128.0
|
||||
# Some channels have patterns
|
||||
elif c < 8:
|
||||
for h in range(64):
|
||||
for w in range(64):
|
||||
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
|
||||
# Some channels are more uniform
|
||||
else:
|
||||
latent[c, :, :] = c * 0.1
|
||||
|
||||
# Add per-sample variation (different "subjects")
|
||||
latent = latent * (1.0 + i * 0.2)
|
||||
|
||||
# Add realistic VAE-like noise/variation
|
||||
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
|
||||
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify eigenvalues are NOT all saturated at 1.0
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(20):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# Critical: eigenvalues should NOT all be 1.0
|
||||
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
|
||||
total = len(non_zero_eigvals)
|
||||
percent_at_max = (at_max / total * 100) if total > 0 else 0
|
||||
|
||||
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
|
||||
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
|
||||
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
|
||||
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
|
||||
|
||||
# FAIL if too many eigenvalues are saturated at 1.0
|
||||
assert percent_at_max < 80, (
|
||||
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
|
||||
f"This indicates the normalization bug - raw eigenvalues are not being "
|
||||
f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
|
||||
)
|
||||
|
||||
# Should have good diversity
|
||||
assert np.std(non_zero_eigvals) > 0.1, (
|
||||
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
|
||||
f"Should see diverse eigenvalues, not all the same value."
|
||||
)
|
||||
|
||||
# Mean should be in reasonable range (not all 1.0)
|
||||
mean_eigval = np.mean(non_zero_eigvals)
|
||||
assert 0.05 < mean_eigval < 0.9, (
|
||||
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
|
||||
f"If mean ≈ 1.0, eigenvalues are saturated."
|
||||
)
|
||||
|
||||
def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
|
||||
"""
|
||||
Test that datasets with more variance produce more diverse eigenvalues.
|
||||
|
||||
This ensures the normalization preserves relative information.
|
||||
"""
|
||||
# Create two preprocessors with different data variance
|
||||
results = {}
|
||||
|
||||
for variance_scale in [0.5, 2.0]:
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(15):
|
||||
latent = torch.zeros(16, 32, 32, dtype=torch.float32)
|
||||
|
||||
# Create varied patterns
|
||||
for c in range(16):
|
||||
for h in range(32):
|
||||
for w in range(32):
|
||||
latent[c, h, w] = (
|
||||
np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale
|
||||
)
|
||||
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(output_path), framework="pt", device="cpu") as f:
|
||||
eigvals = []
|
||||
for i in range(15):
|
||||
ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
eigvals.extend(ev[ev > 1e-6])
|
||||
|
||||
results[variance_scale] = {
|
||||
'mean': np.mean(eigvals),
|
||||
'std': np.std(eigvals),
|
||||
'range': (np.min(eigvals), np.max(eigvals))
|
||||
}
|
||||
|
||||
print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
|
||||
print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
|
||||
|
||||
# Both should have diversity (not saturated)
|
||||
for scale in [0.5, 2.0]:
|
||||
assert results[scale]['std'] > 0.1, (
|
||||
f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,252 +0,0 @@
|
||||
"""
|
||||
Tests to verify CDC eigenvalue scaling is correct.
|
||||
|
||||
These tests ensure eigenvalues are properly scaled to prevent training loss explosion.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestEigenvalueScaling:
|
||||
"""Test that eigenvalues are properly scaled to reasonable ranges"""
|
||||
|
||||
def test_eigenvalues_in_correct_range(self, tmp_path):
|
||||
"""Verify eigenvalues are scaled to ~0.01-1.0 range, not millions"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Add deterministic latents with structured patterns
|
||||
for i in range(10):
|
||||
# Create gradient pattern: values from 0 to 2.0 across spatial dims
|
||||
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
||||
# Add per-sample variation
|
||||
latent = latent + i * 0.1
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify eigenvalues are in correct range
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
|
||||
# Filter out zero eigenvalues (from padding when k < d_cdc)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# Critical assertions for eigenvalue scale
|
||||
assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
|
||||
assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
|
||||
assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
|
||||
|
||||
# Check sqrt (used in noise) is reasonable
|
||||
sqrt_max = np.sqrt(all_eigvals.max())
|
||||
assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
|
||||
|
||||
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
|
||||
print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
|
||||
print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
|
||||
print(f"✓ sqrt(max): {sqrt_max:.4f}")
|
||||
|
||||
def test_eigenvalues_not_all_zero(self, tmp_path):
|
||||
"""Ensure eigenvalues are not all zero (indicating computation failure)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
# Create deterministic pattern
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# With clamping, eigenvalues will be in range [1e-3, gamma*1.0]
|
||||
# Check that we have some non-zero eigenvalues
|
||||
assert len(non_zero_eigvals) > 0, "All eigenvalues are zero - computation failed"
|
||||
|
||||
# Check they're in the expected clamped range
|
||||
assert np.all(non_zero_eigvals >= 1e-3), f"Some eigenvalues below clamp min: {np.min(non_zero_eigvals)}"
|
||||
assert np.all(non_zero_eigvals <= 1.0), f"Some eigenvalues above clamp max: {np.max(non_zero_eigvals)}"
|
||||
|
||||
print(f"\n✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
|
||||
print(f"✓ Range: [{np.min(non_zero_eigvals):.4f}, {np.max(non_zero_eigvals):.4f}]")
|
||||
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
|
||||
|
||||
def test_fp16_storage_no_overflow(self, tmp_path):
|
||||
"""Verify fp16 storage doesn't overflow (max fp16 = 65,504)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
# Create deterministic pattern with higher magnitude
|
||||
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
|
||||
latent = latent + i * 0.3
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check dtype is fp16
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
|
||||
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
|
||||
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
|
||||
|
||||
# Check no values near fp16 max (would indicate overflow)
|
||||
FP16_MAX = 65504
|
||||
max_eigval = eigvals.max().item()
|
||||
|
||||
assert max_eigval < 100, (
|
||||
f"Eigenvalue {max_eigval:.2e} is suspiciously large for fp16 storage. "
|
||||
f"May indicate overflow (fp16 max = {FP16_MAX})"
|
||||
)
|
||||
|
||||
print(f"\n✓ Storage dtype: {eigvals.dtype}")
|
||||
print(f"✓ Max eigenvalue: {max_eigval:.4f} (safe for fp16)")
|
||||
|
||||
def test_latent_magnitude_preserved(self, tmp_path):
|
||||
"""Verify latent magnitude is preserved (no unwanted normalization)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Store original latents with deterministic patterns
|
||||
original_latents = []
|
||||
for i in range(10):
|
||||
# Create structured pattern with known magnitude
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
|
||||
original_latents.append(latent.clone())
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Compute original latent statistics
|
||||
orig_std = torch.stack(original_latents).std().item()
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# The stored latents should preserve original magnitude
|
||||
stored_latents_std = np.std([s.latent for s in preprocessor.batcher.samples])
|
||||
|
||||
# Should be similar to original (within 20% due to potential batching effects)
|
||||
assert 0.8 * orig_std < stored_latents_std < 1.2 * orig_std, (
|
||||
f"Stored latent std {stored_latents_std:.2f} differs too much from "
|
||||
f"original {orig_std:.2f}. Latent magnitude was not preserved."
|
||||
)
|
||||
|
||||
print(f"\n✓ Original latent std: {orig_std:.2f}")
|
||||
print(f"✓ Stored latent std: {stored_latents_std:.2f}")
|
||||
|
||||
|
||||
class TestTrainingLossScale:
|
||||
"""Test that eigenvalues produce reasonable loss magnitudes"""
|
||||
|
||||
def test_noise_magnitude_reasonable(self, tmp_path):
|
||||
"""Verify CDC noise has reasonable magnitude for training"""
|
||||
from library.cdc_fm import GammaBDataset
|
||||
|
||||
# Create CDC cache with deterministic data
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
# Create deterministic pattern
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Load and compute noise
|
||||
gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
# Simulate training scenario with deterministic data
|
||||
batch_size = 3
|
||||
latents = torch.zeros(batch_size, 16, 4, 4)
|
||||
for b in range(batch_size):
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
||||
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
|
||||
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
||||
|
||||
# Check noise magnitude
|
||||
noise_std = noise.std().item()
|
||||
latent_std = latents.std().item()
|
||||
|
||||
# Noise should be similar magnitude to input latents (within 10x)
|
||||
ratio = noise_std / latent_std
|
||||
assert 0.1 < ratio < 10.0, (
|
||||
f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
|
||||
f"ratio {ratio:.2f} is too extreme. Will cause training instability."
|
||||
)
|
||||
|
||||
# Simulated MSE loss should be reasonable
|
||||
simulated_loss = torch.mean((noise - latents) ** 2).item()
|
||||
assert simulated_loss < 100.0, (
|
||||
f"Simulated MSE loss {simulated_loss:.2f} is too high. "
|
||||
f"Should be O(0.1-1.0) for stable training."
|
||||
)
|
||||
|
||||
print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
|
||||
print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,220 +0,0 @@
|
||||
"""
|
||||
Comprehensive CDC Eigenvalue Validation Tests
|
||||
|
||||
These tests ensure that eigenvalue computation and scaling work correctly
|
||||
across various scenarios, including:
|
||||
- Scaling to reasonable ranges
|
||||
- Handling high-dimensional data
|
||||
- Preserving latent information
|
||||
- Preventing computational artifacts
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestEigenvalueScaling:
|
||||
"""Verify eigenvalue scaling and computational properties"""
|
||||
|
||||
def test_eigenvalues_in_correct_range(self, tmp_path):
|
||||
"""
|
||||
Verify eigenvalues are scaled to ~0.01-1.0 range, not millions.
|
||||
|
||||
Ensures:
|
||||
- No numerical explosions
|
||||
- Reasonable eigenvalue magnitudes
|
||||
- Consistent scaling across samples
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create deterministic latents with structured patterns
|
||||
for i in range(10):
|
||||
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
||||
latent = latent + i * 0.1
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify eigenvalues are in correct range
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# Critical assertions for eigenvalue scale
|
||||
assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
|
||||
assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
|
||||
assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
|
||||
|
||||
# Check sqrt (used in noise) is reasonable
|
||||
sqrt_max = np.sqrt(all_eigvals.max())
|
||||
assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
|
||||
|
||||
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
|
||||
print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
|
||||
print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
|
||||
print(f"✓ sqrt(max): {sqrt_max:.4f}")
|
||||
|
||||
def test_high_dimensional_latents_scaling(self, tmp_path):
|
||||
"""
|
||||
Verify scaling for high-dimensional realistic latents.
|
||||
|
||||
Key scenarios:
|
||||
- High-dimensional data (16×64×64)
|
||||
- Varied channel structures
|
||||
- Realistic VAE-like data
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create 20 samples with realistic varied structure
|
||||
for i in range(20):
|
||||
# High-dimensional latent like FLUX
|
||||
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
|
||||
|
||||
# Create varied structure across the latent
|
||||
for c in range(16):
|
||||
# Different patterns across channels
|
||||
if c < 4:
|
||||
for h in range(64):
|
||||
for w in range(64):
|
||||
latent[c, h, w] = (h + w) / 128.0
|
||||
elif c < 8:
|
||||
for h in range(64):
|
||||
for w in range(64):
|
||||
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
|
||||
else:
|
||||
latent[c, :, :] = c * 0.1
|
||||
|
||||
# Add per-sample variation
|
||||
latent = latent * (1.0 + i * 0.2)
|
||||
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
|
||||
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify eigenvalues are not all saturated
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(20):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
|
||||
total = len(non_zero_eigvals)
|
||||
percent_at_max = (at_max / total * 100) if total > 0 else 0
|
||||
|
||||
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
|
||||
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
|
||||
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
|
||||
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
|
||||
|
||||
# Fail if too many eigenvalues are saturated
|
||||
assert percent_at_max < 80, (
|
||||
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
|
||||
f"Raw eigenvalues not scaled before clamping. "
|
||||
f"Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
|
||||
)
|
||||
|
||||
# Should have good diversity
|
||||
assert np.std(non_zero_eigvals) > 0.1, (
|
||||
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
|
||||
f"Should see diverse eigenvalues, not all the same."
|
||||
)
|
||||
|
||||
# Mean should be in reasonable range
|
||||
mean_eigval = np.mean(non_zero_eigvals)
|
||||
assert 0.05 < mean_eigval < 0.9, (
|
||||
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
|
||||
f"If mean ≈ 1.0, eigenvalues are saturated."
|
||||
)
|
||||
|
||||
def test_noise_magnitude_reasonable(self, tmp_path):
|
||||
"""
|
||||
Verify CDC noise has reasonable magnitude for training.
|
||||
|
||||
Ensures noise:
|
||||
- Has similar scale to input latents
|
||||
- Won't destabilize training
|
||||
- Preserves input variance
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Load and compute noise
|
||||
gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
# Simulate training scenario with deterministic data
|
||||
batch_size = 3
|
||||
latents = torch.zeros(batch_size, 16, 4, 4)
|
||||
for b in range(batch_size):
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
||||
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
|
||||
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
||||
|
||||
# Check noise magnitude
|
||||
noise_std = noise.std().item()
|
||||
latent_std = latents.std().item()
|
||||
|
||||
# Noise should be similar magnitude to input latents (within 10x)
|
||||
ratio = noise_std / latent_std
|
||||
assert 0.1 < ratio < 10.0, (
|
||||
f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
|
||||
f"ratio {ratio:.2f} is too extreme. Will cause training instability."
|
||||
)
|
||||
|
||||
# Simulated MSE loss should be reasonable
|
||||
simulated_loss = torch.mean((noise - latents) ** 2).item()
|
||||
assert simulated_loss < 100.0, (
|
||||
f"Simulated MSE loss {simulated_loss:.2f} is too high. "
|
||||
f"Should be O(0.1-1.0) for stable training."
|
||||
)
|
||||
|
||||
print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
|
||||
print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,297 +0,0 @@
|
||||
"""
|
||||
CDC Gradient Flow Verification Tests
|
||||
|
||||
This module provides testing of:
|
||||
1. Mock dataset gradient preservation
|
||||
2. Real dataset gradient flow
|
||||
3. Various time steps and computation paths
|
||||
4. Fallback and edge case scenarios
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation
|
||||
|
||||
|
||||
class MockGammaBDataset:
|
||||
"""
|
||||
Mock implementation of GammaBDataset for testing gradient flow
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Simple initialization that doesn't require file loading
|
||||
"""
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def compute_sigma_t_x(
|
||||
self,
|
||||
eigenvectors: torch.Tensor,
|
||||
eigenvalues: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Simplified implementation of compute_sigma_t_x for testing
|
||||
"""
|
||||
# Store original shape to restore later
|
||||
orig_shape = x.shape
|
||||
|
||||
# Flatten x if it's 4D
|
||||
if x.dim() == 4:
|
||||
B, C, H, W = x.shape
|
||||
x = x.reshape(B, -1) # (B, C*H*W)
|
||||
|
||||
# Validate dimensions
|
||||
assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
|
||||
assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"
|
||||
|
||||
# Early return for t=0 with gradient preservation
|
||||
if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
|
||||
return x.reshape(orig_shape)
|
||||
|
||||
# Compute Σ_t @ x
|
||||
# V^T x
|
||||
Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
|
||||
|
||||
# sqrt(λ) * V^T x
|
||||
sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
|
||||
sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
|
||||
|
||||
# V @ (sqrt(λ) * V^T x)
|
||||
gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
|
||||
|
||||
# Interpolate between original and noisy latent
|
||||
result = (1 - t) * x + t * gamma_sqrt_x
|
||||
|
||||
# Restore original shape
|
||||
result = result.reshape(orig_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TestCDCGradientFlow:
|
||||
"""
|
||||
Gradient flow testing for CDC noise transformations
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Prepare consistent test environment"""
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def test_mock_gradient_flow_near_zero_time_step(self):
|
||||
"""
|
||||
Verify gradient flow preservation for near-zero time steps
|
||||
using mock dataset with learnable time embeddings
|
||||
"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create a learnable time embedding with small initial value
|
||||
t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Generate mock latent and CDC components
|
||||
batch_size, latent_dim = 4, 64
|
||||
latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
|
||||
|
||||
# Create mock eigenvectors and eigenvalues
|
||||
eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
|
||||
eigenvalues = torch.rand(batch_size, 8, device=self.device)
|
||||
|
||||
# Ensure eigenvectors and eigenvalues are meaningful
|
||||
eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
|
||||
eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
|
||||
|
||||
# Use the mock dataset
|
||||
mock_dataset = MockGammaBDataset()
|
||||
|
||||
# Compute noisy latent with gradient tracking
|
||||
noisy_latent = mock_dataset.compute_sigma_t_x(
|
||||
eigenvectors,
|
||||
eigenvalues,
|
||||
latent,
|
||||
t
|
||||
)
|
||||
|
||||
# Compute a dummy loss to check gradient flow
|
||||
loss = noisy_latent.sum()
|
||||
|
||||
# Compute gradients
|
||||
loss.backward()
|
||||
|
||||
# Assertions to verify gradient flow
|
||||
assert t.grad is not None, "Time embedding gradient should be computed"
|
||||
assert latent.grad is not None, "Input latent gradient should be computed"
|
||||
|
||||
# Check gradient magnitudes are non-zero
|
||||
t_grad_magnitude = torch.abs(t.grad).sum()
|
||||
latent_grad_magnitude = torch.abs(latent.grad).sum()
|
||||
|
||||
assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
|
||||
assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"
|
||||
|
||||
def test_gradient_flow_with_multiple_time_steps(self):
|
||||
"""
|
||||
Verify gradient flow across different time step values
|
||||
"""
|
||||
# Test time steps
|
||||
time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]
|
||||
|
||||
for time_val in time_steps:
|
||||
# Create a learnable time embedding
|
||||
t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Generate mock latent and CDC components
|
||||
batch_size, latent_dim = 4, 64
|
||||
latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
|
||||
|
||||
# Create mock eigenvectors and eigenvalues
|
||||
eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
|
||||
eigenvalues = torch.rand(batch_size, 8, device=self.device)
|
||||
|
||||
# Ensure eigenvectors and eigenvalues are meaningful
|
||||
eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
|
||||
eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
|
||||
|
||||
# Use the mock dataset
|
||||
mock_dataset = MockGammaBDataset()
|
||||
|
||||
# Compute noisy latent with gradient tracking
|
||||
noisy_latent = mock_dataset.compute_sigma_t_x(
|
||||
eigenvectors,
|
||||
eigenvalues,
|
||||
latent,
|
||||
t
|
||||
)
|
||||
|
||||
# Compute a dummy loss to check gradient flow
|
||||
loss = noisy_latent.sum()
|
||||
|
||||
# Compute gradients
|
||||
loss.backward()
|
||||
|
||||
# Assertions to verify gradient flow
|
||||
t_grad_magnitude = torch.abs(t.grad).sum()
|
||||
latent_grad_magnitude = torch.abs(latent.grad).sum()
|
||||
|
||||
assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
|
||||
assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"
|
||||
|
||||
# Reset gradients for next iteration
|
||||
t.grad.zero_() if t.grad is not None else None
|
||||
latent.grad.zero_() if latent.grad is not None else None
|
||||
|
||||
def test_gradient_flow_with_real_dataset(self, tmp_path):
|
||||
"""
|
||||
Test gradient flow with real CDC dataset
|
||||
"""
|
||||
# Create cache with uniform shapes
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_gradient.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
|
||||
# Prepare test noise
|
||||
torch.manual_seed(42)
|
||||
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']
|
||||
|
||||
# Apply CDC transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Verify gradient flow
|
||||
assert noise_out.requires_grad, "Output should require gradients"
|
||||
|
||||
loss = noise_out.sum()
|
||||
loss.backward()
|
||||
|
||||
assert noise.grad is not None, "Gradients should flow back to input noise"
|
||||
assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
|
||||
assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
|
||||
assert (noise.grad != 0).any(), "Gradients should not be all zeros"
|
||||
|
||||
def test_gradient_flow_with_fallback(self, tmp_path):
|
||||
"""
|
||||
Test gradient flow when using Gaussian fallback (shape mismatch)
|
||||
|
||||
Ensures that cloned tensors maintain gradient flow correctly
|
||||
even when shape mismatch triggers Gaussian noise
|
||||
"""
|
||||
# Create cache with one shape
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
metadata = {'image_key': 'test_image_0'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_fallback_gradient.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
|
||||
# Use different shape at runtime (will trigger fallback)
|
||||
runtime_shape = (16, 64, 64)
|
||||
noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
image_keys = ['test_image_0']
|
||||
|
||||
# Apply transformation (should fallback to Gaussian for this sample)
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Ensure gradients still flow through fallback path
|
||||
assert noise_out.requires_grad, "Fallback output should require gradients"
|
||||
|
||||
loss = noise_out.sum()
|
||||
loss.backward()
|
||||
|
||||
assert noise.grad is not None, "Gradients should flow even in fallback case"
|
||||
assert not torch.isnan(noise.grad).any(), "Fallback gradients should not contain NaN"
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""
|
||||
Configure custom markers for CDC gradient flow tests
|
||||
"""
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"mock_dataset: mark test using mock dataset for simplified gradient testing"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"real_dataset: mark test using real dataset for comprehensive gradient testing"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
157
tests/library/test_cdc_hash_validation.py
Normal file
157
tests/library/test_cdc_hash_validation.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Test CDC config hash generation and cache invalidation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestCDCConfigHash:
|
||||
"""
|
||||
Test that CDC config hash properly invalidates cache when dataset or parameters change
|
||||
"""
|
||||
|
||||
def test_same_config_produces_same_hash(self, tmp_path):
|
||||
"""
|
||||
Test that identical configurations produce identical hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash == preprocessor2.config_hash
|
||||
|
||||
def test_different_dataset_dirs_produce_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different dataset directories produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset2")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_different_k_neighbors_produces_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different k_neighbors values produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=10, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_different_d_cdc_produces_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different d_cdc values produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_different_gamma_produces_different_hash(self, tmp_path):
|
||||
"""
|
||||
Test that different gamma values produce different hashes
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=2.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash != preprocessor2.config_hash
|
||||
|
||||
def test_multiple_dataset_dirs_order_independent(self, tmp_path):
|
||||
"""
|
||||
Test that dataset directory order doesn't affect hash (they are sorted)
|
||||
"""
|
||||
preprocessor1 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu",
|
||||
dataset_dirs=[str(tmp_path / "dataset1"), str(tmp_path / "dataset2")]
|
||||
)
|
||||
|
||||
preprocessor2 = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu",
|
||||
dataset_dirs=[str(tmp_path / "dataset2"), str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert preprocessor1.config_hash == preprocessor2.config_hash
|
||||
|
||||
def test_hash_length_is_8_chars(self, tmp_path):
|
||||
"""
|
||||
Test that hash is exactly 8 characters (hex)
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
assert len(preprocessor.config_hash) == 8
|
||||
# Verify it's hex
|
||||
int(preprocessor.config_hash, 16) # Should not raise
|
||||
|
||||
def test_filename_includes_hash(self, tmp_path):
|
||||
"""
|
||||
Test that CDC filenames include the config hash
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0,
|
||||
device="cpu", dataset_dirs=[str(tmp_path / "dataset1")]
|
||||
)
|
||||
|
||||
latents_path = str(tmp_path / "image_0512x0768_flux.npz")
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, preprocessor.config_hash)
|
||||
|
||||
# Should be: image_0512x0768_flux_cdc_<hash>.npz
|
||||
expected = str(tmp_path / f"image_0512x0768_flux_cdc_{preprocessor.config_hash}.npz")
|
||||
assert cdc_path == expected
|
||||
|
||||
def test_backward_compatibility_no_hash(self, tmp_path):
|
||||
"""
|
||||
Test that get_cdc_npz_path works without hash (backward compatibility)
|
||||
"""
|
||||
latents_path = str(tmp_path / "image_0512x0768_flux.npz")
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None)
|
||||
|
||||
# Should be: image_0512x0768_flux_cdc.npz (no hash suffix)
|
||||
expected = str(tmp_path / "image_0512x0768_flux_cdc.npz")
|
||||
assert cdc_path == expected
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
Test comparing interpolation vs pad/truncate for CDC preprocessing.
|
||||
|
||||
This test quantifies the difference between the two approaches.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TestInterpolationComparison:
|
||||
"""Compare interpolation vs pad/truncate"""
|
||||
|
||||
def test_intermediate_representation_quality(self):
|
||||
"""Compare intermediate representation quality for CDC computation"""
|
||||
# Create test latents with different sizes - deterministic
|
||||
latent_small = torch.zeros(16, 4, 4)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
|
||||
|
||||
latent_large = torch.zeros(16, 8, 8)
|
||||
for c in range(16):
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
|
||||
|
||||
target_h, target_w = 6, 6 # Median size
|
||||
|
||||
# Method 1: Interpolation
|
||||
def interpolate_method(latent, target_h, target_w):
|
||||
latent_input = latent.unsqueeze(0) # (1, C, H, W)
|
||||
latent_resized = F.interpolate(
|
||||
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
|
||||
)
|
||||
# Resize back
|
||||
C, H, W = latent.shape
|
||||
latent_reconstructed = F.interpolate(
|
||||
latent_resized, size=(H, W), mode='bilinear', align_corners=False
|
||||
)
|
||||
error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
|
||||
relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
|
||||
return relative_error
|
||||
|
||||
# Method 2: Pad/Truncate
|
||||
def pad_truncate_method(latent, target_h, target_w):
|
||||
C, H, W = latent.shape
|
||||
latent_flat = latent.reshape(-1)
|
||||
target_dim = C * target_h * target_w
|
||||
current_dim = C * H * W
|
||||
|
||||
if current_dim == target_dim:
|
||||
latent_resized_flat = latent_flat
|
||||
elif current_dim > target_dim:
|
||||
# Truncate
|
||||
latent_resized_flat = latent_flat[:target_dim]
|
||||
else:
|
||||
# Pad
|
||||
latent_resized_flat = torch.zeros(target_dim)
|
||||
latent_resized_flat[:current_dim] = latent_flat
|
||||
|
||||
# Resize back
|
||||
if current_dim == target_dim:
|
||||
latent_reconstructed_flat = latent_resized_flat
|
||||
elif current_dim > target_dim:
|
||||
# Pad back
|
||||
latent_reconstructed_flat = torch.zeros(current_dim)
|
||||
latent_reconstructed_flat[:target_dim] = latent_resized_flat
|
||||
else:
|
||||
# Truncate back
|
||||
latent_reconstructed_flat = latent_resized_flat[:current_dim]
|
||||
|
||||
latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
|
||||
error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
|
||||
relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
|
||||
return relative_error
|
||||
|
||||
# Compare for small latent (needs padding)
|
||||
interp_error_small = interpolate_method(latent_small, target_h, target_w)
|
||||
pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
|
||||
|
||||
# Compare for large latent (needs truncation)
|
||||
interp_error_large = interpolate_method(latent_large, target_h, target_w)
|
||||
truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Reconstruction Error Comparison")
|
||||
print("=" * 60)
|
||||
print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
|
||||
print(f" Interpolation error: {interp_error_small:.6f}")
|
||||
print(f" Pad/truncate error: {pad_error_small:.6f}")
|
||||
if pad_error_small > 0:
|
||||
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
|
||||
else:
|
||||
print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
|
||||
print(" BUT the intermediate representation is corrupted with zeros!")
|
||||
|
||||
print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
|
||||
print(f" Interpolation error: {interp_error_large:.6f}")
|
||||
print(f" Pad/truncate error: {truncate_error_large:.6f}")
|
||||
if truncate_error_large > 0:
|
||||
print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
|
||||
|
||||
# The key insight: Reconstruction error is NOT what matters for CDC!
|
||||
# What matters is the INTERMEDIATE representation quality used for geometry estimation.
|
||||
# Pad/truncate may have good reconstruction, but the intermediate is corrupted.
|
||||
|
||||
print("\nKey insight: For CDC, intermediate representation quality matters,")
|
||||
print("not reconstruction error. Interpolation preserves spatial structure.")
|
||||
|
||||
# Verify interpolation errors are reasonable
|
||||
assert interp_error_small < 1.0, "Interpolation should have reasonable error"
|
||||
assert interp_error_large < 1.0, "Interpolation should have reasonable error"
|
||||
|
||||
def test_spatial_structure_preservation(self):
|
||||
"""Test that interpolation preserves spatial structure better than pad/truncate"""
|
||||
# Create a latent with clear spatial pattern (gradient)
|
||||
C, H, W = 16, 4, 4
|
||||
latent = torch.zeros(C, H, W)
|
||||
for i in range(H):
|
||||
for j in range(W):
|
||||
latent[:, i, j] = i * W + j # Gradient pattern
|
||||
|
||||
target_h, target_w = 6, 6
|
||||
|
||||
# Interpolation
|
||||
latent_input = latent.unsqueeze(0)
|
||||
latent_interp = F.interpolate(
|
||||
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Pad/truncate
|
||||
latent_flat = latent.reshape(-1)
|
||||
target_dim = C * target_h * target_w
|
||||
latent_padded = torch.zeros(target_dim)
|
||||
latent_padded[:len(latent_flat)] = latent_flat
|
||||
latent_pad = latent_padded.reshape(C, target_h, target_w)
|
||||
|
||||
# Check gradient preservation
|
||||
# For interpolation, adjacent pixels should have smooth gradients
|
||||
grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
|
||||
grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
|
||||
|
||||
# For padding, there will be abrupt changes (gradient to zero)
|
||||
grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
|
||||
grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Spatial Structure Preservation")
|
||||
print("=" * 60)
|
||||
print("\nGradient smoothness (lower is smoother):")
|
||||
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
|
||||
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
|
||||
|
||||
# Padding introduces larger gradients due to abrupt zeros
|
||||
assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
|
||||
assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,412 +0,0 @@
|
||||
"""
|
||||
Performance and Interpolation Tests for CDC Flow Matching
|
||||
|
||||
This module provides testing of:
|
||||
1. Computational overhead
|
||||
2. Noise injection properties
|
||||
3. Interpolation vs. pad/truncate methods
|
||||
4. Spatial structure preservation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import time
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestCDCPerformanceAndInterpolation:
|
||||
"""
|
||||
Comprehensive performance testing for CDC Flow Matching
|
||||
Covers computational efficiency, noise properties, and interpolation quality
|
||||
"""
|
||||
|
||||
@pytest.fixture(params=[
|
||||
(3, 32, 32), # Small latent: typical for compact representations
|
||||
(3, 64, 64), # Medium latent: standard feature maps
|
||||
(3, 128, 128) # Large latent: high-resolution feature spaces
|
||||
])
|
||||
def latent_sizes(self, request):
|
||||
"""
|
||||
Parametrized fixture generating test cases for different latent sizes.
|
||||
|
||||
Rationale:
|
||||
- Tests robustness across various computational scales
|
||||
- Ensures consistent behavior from compact to large representations
|
||||
- Identifies potential dimensionality-related performance bottlenecks
|
||||
"""
|
||||
return request.param
|
||||
|
||||
def test_computational_overhead(self, latent_sizes):
|
||||
"""
|
||||
Measure computational overhead of CDC preprocessing across latent sizes.
|
||||
|
||||
Performance Verification Objectives:
|
||||
1. Verify preprocessing time scales predictably with input dimensions
|
||||
2. Ensure adaptive k-neighbors works efficiently
|
||||
3. Validate computational overhead remains within acceptable bounds
|
||||
|
||||
Performance Metrics:
|
||||
- Total preprocessing time
|
||||
- Per-sample processing time
|
||||
- Computational complexity indicators
|
||||
"""
|
||||
# Tuned preprocessing configuration
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=256, # Comprehensive neighborhood exploration
|
||||
d_cdc=8, # Geometric embedding dimensionality
|
||||
debug=True, # Enable detailed performance logging
|
||||
adaptive_k=True # Dynamic neighborhood size adjustment
|
||||
)
|
||||
|
||||
# Set a fixed random seed for reproducibility
|
||||
torch.manual_seed(42) # Consistent random generation
|
||||
|
||||
# Generate representative latent batch
|
||||
batch_size = 32
|
||||
latents = torch.randn(batch_size, *latent_sizes)
|
||||
|
||||
# Precision timing of preprocessing
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
# Add latents with traceable metadata
|
||||
for i, latent in enumerate(latents):
|
||||
preprocessor.add_latent(
|
||||
latent,
|
||||
global_idx=i,
|
||||
metadata={'image_key': f'perf_test_image_{i}'}
|
||||
)
|
||||
|
||||
# Compute CDC results
|
||||
cdc_path = preprocessor.compute_all(tmp_file.name)
|
||||
|
||||
# Calculate precise preprocessing metrics
|
||||
end_time = time.perf_counter()
|
||||
preprocessing_time = end_time - start_time
|
||||
per_sample_time = preprocessing_time / batch_size
|
||||
|
||||
# Performance reporting and assertions
|
||||
input_volume = np.prod(latent_sizes)
|
||||
time_complexity_indicator = preprocessing_time / input_volume
|
||||
|
||||
print(f"\nPerformance Breakdown:")
|
||||
print(f" Latent Size: {latent_sizes}")
|
||||
print(f" Total Samples: {batch_size}")
|
||||
print(f" Input Volume: {input_volume}")
|
||||
print(f" Total Time: {preprocessing_time:.4f} seconds")
|
||||
print(f" Per Sample Time: {per_sample_time:.6f} seconds")
|
||||
print(f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel")
|
||||
|
||||
# Adaptive thresholds based on input dimensions
|
||||
max_total_time = 10.0 # Base threshold
|
||||
max_per_sample_time = 2.0 # Per-sample time threshold (more lenient)
|
||||
|
||||
# Different time complexity thresholds for different latent sizes
|
||||
max_time_complexity = (
|
||||
1e-2 if np.prod(latent_sizes) <= 3072 else # Smaller latents
|
||||
1e-4 # Standard latents
|
||||
)
|
||||
|
||||
# Performance assertions with informative error messages
|
||||
assert preprocessing_time < max_total_time, (
|
||||
f"Total preprocessing time exceeded threshold!\n"
|
||||
f" Latent Size: {latent_sizes}\n"
|
||||
f" Total Time: {preprocessing_time:.4f} seconds\n"
|
||||
f" Threshold: {max_total_time} seconds"
|
||||
)
|
||||
|
||||
assert per_sample_time < max_per_sample_time, (
|
||||
f"Per-sample processing time exceeded threshold!\n"
|
||||
f" Latent Size: {latent_sizes}\n"
|
||||
f" Per Sample Time: {per_sample_time:.6f} seconds\n"
|
||||
f" Threshold: {max_per_sample_time} seconds"
|
||||
)
|
||||
|
||||
# More adaptable time complexity check
|
||||
assert time_complexity_indicator < max_time_complexity, (
|
||||
f"Time complexity scaling exceeded expectations!\n"
|
||||
f" Latent Size: {latent_sizes}\n"
|
||||
f" Input Volume: {input_volume}\n"
|
||||
f" Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n"
|
||||
f" Threshold: {max_time_complexity} seconds/voxel"
|
||||
)
|
||||
|
||||
def test_noise_distribution(self, latent_sizes):
|
||||
"""
|
||||
Verify CDC noise injection quality and properties.
|
||||
|
||||
Based on test plan objectives:
|
||||
1. CDC noise is actually being generated (not all Gaussian fallback)
|
||||
2. Eigenvalues are valid (non-negative, bounded)
|
||||
3. CDC components are finite and usable for noise generation
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=16, # Reduced to match batch size
|
||||
d_cdc=8,
|
||||
gamma=1.0,
|
||||
debug=True,
|
||||
adaptive_k=True
|
||||
)
|
||||
|
||||
# Set a fixed random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Generate batch of latents
|
||||
batch_size = 32
|
||||
latents = torch.randn(batch_size, *latent_sizes)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
|
||||
# Add latents with metadata
|
||||
for i, latent in enumerate(latents):
|
||||
preprocessor.add_latent(
|
||||
latent,
|
||||
global_idx=i,
|
||||
metadata={'image_key': f'noise_dist_image_{i}'}
|
||||
)
|
||||
|
||||
# Compute CDC results
|
||||
cdc_path = preprocessor.compute_all(tmp_file.name)
|
||||
|
||||
# Analyze noise properties
|
||||
dataset = GammaBDataset(cdc_path)
|
||||
|
||||
# Track samples that used CDC vs Gaussian fallback
|
||||
cdc_samples = 0
|
||||
gaussian_samples = 0
|
||||
eigenvalue_stats = {
|
||||
'min': float('inf'),
|
||||
'max': float('-inf'),
|
||||
'mean': 0.0,
|
||||
'sum': 0.0
|
||||
}
|
||||
|
||||
# Verify each sample's CDC components
|
||||
for i in range(batch_size):
|
||||
image_key = f'noise_dist_image_{i}'
|
||||
|
||||
# Get eigenvectors and eigenvalues
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key])
|
||||
|
||||
# Skip zero eigenvectors (fallback case)
|
||||
if torch.all(eigvecs[0] == 0):
|
||||
gaussian_samples += 1
|
||||
continue
|
||||
|
||||
# Get the top d_cdc eigenvectors and eigenvalues
|
||||
top_eigvecs = eigvecs[0] # (d_cdc, d)
|
||||
top_eigvals = eigvals[0] # (d_cdc,)
|
||||
|
||||
# Basic validity checks
|
||||
assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}"
|
||||
assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}"
|
||||
|
||||
# Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM)
|
||||
assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}"
|
||||
assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}"
|
||||
|
||||
# Update statistics
|
||||
eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item())
|
||||
eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item())
|
||||
eigenvalue_stats['sum'] += top_eigvals.sum().item()
|
||||
|
||||
cdc_samples += 1
|
||||
|
||||
# Compute mean eigenvalue across all CDC samples
|
||||
if cdc_samples > 0:
|
||||
eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * 8) # 8 = d_cdc
|
||||
|
||||
# Print final statistics
|
||||
print(f"\nNoise Distribution Results for latent size {latent_sizes}:")
|
||||
print(f" CDC samples: {cdc_samples}/{batch_size}")
|
||||
print(f" Gaussian fallback: {gaussian_samples}/{batch_size}")
|
||||
print(f" Eigenvalue min: {eigenvalue_stats['min']:.4f}")
|
||||
print(f" Eigenvalue max: {eigenvalue_stats['max']:.4f}")
|
||||
print(f" Eigenvalue mean: {eigenvalue_stats['mean']:.4f}")
|
||||
|
||||
# Assertions based on plan objectives
|
||||
# 1. CDC noise should be generated for most samples
|
||||
assert cdc_samples > 0, "No samples used CDC noise injection"
|
||||
assert gaussian_samples < batch_size // 2, (
|
||||
f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}"
|
||||
)
|
||||
|
||||
# 2. Eigenvalues should be valid (non-negative and bounded)
|
||||
assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative"
|
||||
assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0"
|
||||
|
||||
# 3. Mean eigenvalue should be reasonable (not degenerate)
|
||||
assert eigenvalue_stats['mean'] > 0.05, (
|
||||
f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), "
|
||||
"suggests degenerate CDC components"
|
||||
)
|
||||
|
||||
def test_interpolation_reconstruction(self):
|
||||
"""
|
||||
Compare interpolation vs pad/truncate reconstruction methods for CDC.
|
||||
"""
|
||||
# Create test latents with different sizes - deterministic
|
||||
latent_small = torch.zeros(16, 4, 4)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
|
||||
|
||||
latent_large = torch.zeros(16, 8, 8)
|
||||
for c in range(16):
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
|
||||
|
||||
target_h, target_w = 6, 6 # Median size
|
||||
|
||||
# Method 1: Interpolation
|
||||
def interpolate_method(latent, target_h, target_w):
|
||||
latent_input = latent.unsqueeze(0) # (1, C, H, W)
|
||||
latent_resized = F.interpolate(
|
||||
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
|
||||
)
|
||||
# Resize back
|
||||
C, H, W = latent.shape
|
||||
latent_reconstructed = F.interpolate(
|
||||
latent_resized, size=(H, W), mode='bilinear', align_corners=False
|
||||
)
|
||||
error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
|
||||
relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
|
||||
return relative_error
|
||||
|
||||
# Method 2: Pad/Truncate
|
||||
def pad_truncate_method(latent, target_h, target_w):
|
||||
C, H, W = latent.shape
|
||||
latent_flat = latent.reshape(-1)
|
||||
target_dim = C * target_h * target_w
|
||||
current_dim = C * H * W
|
||||
|
||||
if current_dim == target_dim:
|
||||
latent_resized_flat = latent_flat
|
||||
elif current_dim > target_dim:
|
||||
# Truncate
|
||||
latent_resized_flat = latent_flat[:target_dim]
|
||||
else:
|
||||
# Pad
|
||||
latent_resized_flat = torch.zeros(target_dim)
|
||||
latent_resized_flat[:current_dim] = latent_flat
|
||||
|
||||
# Resize back
|
||||
if current_dim == target_dim:
|
||||
latent_reconstructed_flat = latent_resized_flat
|
||||
elif current_dim > target_dim:
|
||||
# Pad back
|
||||
latent_reconstructed_flat = torch.zeros(current_dim)
|
||||
latent_reconstructed_flat[:target_dim] = latent_resized_flat
|
||||
else:
|
||||
# Truncate back
|
||||
latent_reconstructed_flat = latent_resized_flat[:current_dim]
|
||||
|
||||
latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
|
||||
error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
|
||||
relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
|
||||
return relative_error
|
||||
|
||||
# Compare for small latent (needs padding)
|
||||
interp_error_small = interpolate_method(latent_small, target_h, target_w)
|
||||
pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
|
||||
|
||||
# Compare for large latent (needs truncation)
|
||||
interp_error_large = interpolate_method(latent_large, target_h, target_w)
|
||||
truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Reconstruction Error Comparison")
|
||||
print("=" * 60)
|
||||
print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
|
||||
print(f" Interpolation error: {interp_error_small:.6f}")
|
||||
print(f" Pad/truncate error: {pad_error_small:.6f}")
|
||||
if pad_error_small > 0:
|
||||
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
|
||||
else:
|
||||
print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
|
||||
print(" BUT the intermediate representation is corrupted with zeros!")
|
||||
|
||||
print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
|
||||
print(f" Interpolation error: {interp_error_large:.6f}")
|
||||
print(f" Pad/truncate error: {truncate_error_large:.6f}")
|
||||
if truncate_error_large > 0:
|
||||
print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
|
||||
|
||||
print("\nKey insight: For CDC, intermediate representation quality matters,")
|
||||
print("not reconstruction error. Interpolation preserves spatial structure.")
|
||||
|
||||
# Verify interpolation errors are reasonable
|
||||
assert interp_error_small < 1.0, "Interpolation should have reasonable error"
|
||||
assert interp_error_large < 1.0, "Interpolation should have reasonable error"
|
||||
|
||||
def test_spatial_structure_preservation(self):
|
||||
"""
|
||||
Test that interpolation preserves spatial structure better than pad/truncate.
|
||||
"""
|
||||
# Create a latent with clear spatial pattern (gradient)
|
||||
C, H, W = 16, 4, 4
|
||||
latent = torch.zeros(C, H, W)
|
||||
for i in range(H):
|
||||
for j in range(W):
|
||||
latent[:, i, j] = i * W + j # Gradient pattern
|
||||
|
||||
target_h, target_w = 6, 6
|
||||
|
||||
# Interpolation
|
||||
latent_input = latent.unsqueeze(0)
|
||||
latent_interp = F.interpolate(
|
||||
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Pad/truncate
|
||||
latent_flat = latent.reshape(-1)
|
||||
target_dim = C * target_h * target_w
|
||||
latent_padded = torch.zeros(target_dim)
|
||||
latent_padded[:len(latent_flat)] = latent_flat
|
||||
latent_pad = latent_padded.reshape(C, target_h, target_w)
|
||||
|
||||
# Check gradient preservation
|
||||
# For interpolation, adjacent pixels should have smooth gradients
|
||||
grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
|
||||
grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
|
||||
|
||||
# For padding, there will be abrupt changes (gradient to zero)
|
||||
grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
|
||||
grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Spatial Structure Preservation")
|
||||
print("=" * 60)
|
||||
print("\nGradient smoothness (lower is smoother):")
|
||||
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
|
||||
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
|
||||
|
||||
# Padding introduces larger gradients due to abrupt zeros
|
||||
assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
|
||||
assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""
|
||||
Configure performance benchmarking markers
|
||||
"""
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"performance: mark test to verify CDC-FM computational performance"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"noise_distribution: mark test to verify noise injection properties"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"interpolation: mark test to verify interpolation quality"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -29,7 +29,8 @@ class TestCDCPreprocessorIntegration:
|
||||
Test basic CDC preprocessing with small dataset
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
# Add 10 small latents
|
||||
@@ -51,8 +52,9 @@ class TestCDCPreprocessorIntegration:
|
||||
# Verify files were created
|
||||
assert files_saved == 10
|
||||
|
||||
# Verify first CDC file structure
|
||||
cdc_path = tmp_path / "test_image_0_0004x0004_flux_cdc.npz"
|
||||
# Verify first CDC file structure (with config hash)
|
||||
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
|
||||
assert cdc_path.exists()
|
||||
|
||||
import numpy as np
|
||||
@@ -73,7 +75,8 @@ class TestCDCPreprocessorIntegration:
|
||||
Test CDC preprocessing with variable-size latents (bucketing)
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
|
||||
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
@@ -109,9 +112,15 @@ class TestCDCPreprocessorIntegration:
|
||||
assert files_saved == 10
|
||||
|
||||
import numpy as np
|
||||
# Check shapes are stored in individual files
|
||||
data_0 = np.load(tmp_path / "test_image_0_0004x0004_flux_cdc.npz")
|
||||
data_5 = np.load(tmp_path / "test_image_5_0008x0008_flux_cdc.npz")
|
||||
# Check shapes are stored in individual files (with config hash)
|
||||
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
|
||||
)
|
||||
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
|
||||
)
|
||||
data_0 = np.load(cdc_path_0)
|
||||
data_5 = np.load(cdc_path_5)
|
||||
|
||||
assert tuple(data_0['shape']) == (16, 4, 4)
|
||||
assert tuple(data_5['shape']) == (16, 8, 8)
|
||||
@@ -128,7 +137,8 @@ class TestDeviceConsistency:
|
||||
"""
|
||||
# Create CDC cache on CPU
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
@@ -148,7 +158,7 @@ class TestDeviceConsistency:
|
||||
|
||||
preprocessor.compute_all()
|
||||
|
||||
dataset = GammaBDataset(device="cpu")
|
||||
dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
|
||||
timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
|
||||
@@ -175,7 +185,8 @@ class TestDeviceConsistency:
|
||||
"""
|
||||
# Create CDC cache on CPU
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
@@ -195,7 +206,7 @@ class TestDeviceConsistency:
|
||||
|
||||
preprocessor.compute_all()
|
||||
|
||||
dataset = GammaBDataset(device="cpu")
|
||||
dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
# Create noise and timesteps
|
||||
noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
@@ -236,7 +247,8 @@ class TestCDCEndToEnd:
|
||||
"""
|
||||
# Step 1: Preprocess latents
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash
|
||||
)
|
||||
|
||||
num_samples = 10
|
||||
@@ -257,8 +269,8 @@ class TestCDCEndToEnd:
|
||||
files_saved = preprocessor.compute_all()
|
||||
assert files_saved == num_samples
|
||||
|
||||
# Step 2: Load with GammaBDataset
|
||||
gamma_b_dataset = GammaBDataset(device="cpu")
|
||||
# Step 2: Load with GammaBDataset (use config hash)
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
# Step 3: Use in mock training scenario
|
||||
batch_size = 3
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
"""
|
||||
Tests to validate the CDC rescaling recommendations from paper review.
|
||||
|
||||
These tests check:
|
||||
1. Gamma parameter interaction with rescaling
|
||||
2. Spatial adaptivity of eigenvalue scaling
|
||||
3. Verification of fixed vs adaptive rescaling behavior
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestGammaRescalingInteraction:
|
||||
"""Test that gamma parameter works correctly with eigenvalue rescaling"""
|
||||
|
||||
def test_gamma_scales_eigenvalues_correctly(self, tmp_path):
|
||||
"""Verify gamma multiplier is applied correctly after rescaling"""
|
||||
# Create two preprocessors with different gamma values
|
||||
gamma_values = [0.5, 1.0, 2.0]
|
||||
eigenvalue_results = {}
|
||||
|
||||
for gamma in gamma_values:
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu"
|
||||
)
|
||||
|
||||
# Add identical deterministic data for all runs
|
||||
for i in range(10):
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.1
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / f"test_gamma_{gamma}.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Extract eigenvalues
|
||||
with safe_open(str(output_path), framework="pt", device="cpu") as f:
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0").numpy()
|
||||
eigenvalue_results[gamma] = eigvals
|
||||
|
||||
# With clamping to [1e-3, gamma*1.0], verify gamma changes the upper bound
|
||||
# Gamma 0.5: max eigenvalue should be ~0.5
|
||||
# Gamma 1.0: max eigenvalue should be ~1.0
|
||||
# Gamma 2.0: max eigenvalue should be ~2.0
|
||||
|
||||
max_0p5 = np.max(eigenvalue_results[0.5])
|
||||
max_1p0 = np.max(eigenvalue_results[1.0])
|
||||
max_2p0 = np.max(eigenvalue_results[2.0])
|
||||
|
||||
assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}"
|
||||
assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}"
|
||||
assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}"
|
||||
|
||||
# All should have min of 1e-3 (clamp lower bound)
|
||||
assert np.min(eigenvalue_results[0.5][eigenvalue_results[0.5] > 0]) >= 1e-3
|
||||
assert np.min(eigenvalue_results[1.0][eigenvalue_results[1.0] > 0]) >= 1e-3
|
||||
assert np.min(eigenvalue_results[2.0][eigenvalue_results[2.0] > 0]) >= 1e-3
|
||||
|
||||
print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}")
|
||||
print(f"✓ Gamma 1.0 max: {max_1p0:.4f}")
|
||||
print(f"✓ Gamma 2.0 max: {max_2p0:.4f}")
|
||||
|
||||
def test_large_gamma_maintains_reasonable_scale(self, tmp_path):
|
||||
"""Verify that large gamma values don't cause eigenvalue explosion"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.15
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_large_gamma.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(output_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
max_eigval = np.max(all_eigvals)
|
||||
mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6])
|
||||
|
||||
# With gamma=10.0 and target_scale=0.1, eigenvalues should be ~1.0
|
||||
# But they should still be reasonable (not exploding)
|
||||
assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma"
|
||||
assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma"
|
||||
|
||||
print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}")
|
||||
|
||||
|
||||
class TestSpatialAdaptivityOfRescaling:
|
||||
"""Test spatial variation in eigenvalue scaling"""
|
||||
|
||||
def test_eigenvalues_vary_spatially(self, tmp_path):
|
||||
"""Verify eigenvalues differ across spatially separated clusters"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create two distinct clusters in latent space
|
||||
# Cluster 1: Tight cluster (low variance) - deterministic spread
|
||||
for i in range(10):
|
||||
latent = torch.zeros(16, 4, 4)
|
||||
# Small variation around 0
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 100.0 + i * 0.01
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
# Cluster 2: Loose cluster (high variance) - deterministic spread
|
||||
for i in range(10, 20):
|
||||
latent = torch.ones(16, 4, 4) * 5.0
|
||||
# Large variation around 5.0
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] += (c + h + w) / 10.0 + (i - 10) * 0.2
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_spatial_variation.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(output_path), framework="pt", device="cpu") as f:
|
||||
# Get eigenvalues from both clusters
|
||||
cluster1_eigvals = []
|
||||
cluster2_eigvals = []
|
||||
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
cluster1_eigvals.append(np.max(eigvals))
|
||||
|
||||
for i in range(10, 20):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
cluster2_eigvals.append(np.max(eigvals))
|
||||
|
||||
cluster1_mean = np.mean(cluster1_eigvals)
|
||||
cluster2_mean = np.mean(cluster2_eigvals)
|
||||
|
||||
print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}")
|
||||
print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}")
|
||||
|
||||
# With fixed target_scale rescaling, eigenvalues should be similar
|
||||
# despite different local geometry
|
||||
# This demonstrates the limitation of fixed rescaling
|
||||
ratio = cluster2_mean / (cluster1_mean + 1e-10)
|
||||
print(f"✓ Ratio (loose/tight): {ratio:.2f}")
|
||||
|
||||
# Both should be rescaled to similar magnitude (~0.1 due to target_scale)
|
||||
assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range"
|
||||
assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range"
|
||||
|
||||
|
||||
class TestFixedVsAdaptiveRescaling:
|
||||
"""Compare current fixed rescaling vs paper's adaptive approach"""
|
||||
|
||||
def test_current_rescaling_is_uniform(self, tmp_path):
|
||||
"""Demonstrate that current rescaling produces uniform eigenvalue scales"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create samples with varying local density - deterministic
|
||||
for i in range(20):
|
||||
latent = torch.zeros(16, 4, 4)
|
||||
# Some samples clustered, some isolated
|
||||
if i < 10:
|
||||
# Dense cluster around origin
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 40.0 + i * 0.05
|
||||
else:
|
||||
# Isolated points - larger offset
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 40.0 + i * 2.0
|
||||
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_uniform_rescaling.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(output_path), framework="pt", device="cpu") as f:
|
||||
max_eigenvalues = []
|
||||
for i in range(20):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
vals = eigvals[eigvals > 1e-6]
|
||||
if vals.size: # at least one valid eigen-value
|
||||
max_eigenvalues.append(vals.max())
|
||||
|
||||
if not max_eigenvalues: # safeguard against empty list
|
||||
pytest.skip("no valid eigen-values found")
|
||||
|
||||
max_eigenvalues = np.array(max_eigenvalues)
|
||||
|
||||
# Check coefficient of variation (std / mean)
|
||||
cv = max_eigenvalues.std() / max_eigenvalues.mean()
|
||||
|
||||
print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f}, {np.max(max_eigenvalues):.4f}]")
|
||||
print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}")
|
||||
print(f"✓ Coefficient of variation: {cv:.4f}")
|
||||
|
||||
# With clamping, eigenvalues should have relatively low variation
|
||||
assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping"
|
||||
# Mean should be reasonable (clamped to [1e-3, gamma*1.0] = [1e-3, 1.0])
|
||||
assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -1,132 +1,176 @@
|
||||
"""
|
||||
Standalone tests for CDC-FM integration.
|
||||
Standalone tests for CDC-FM per-file caching.
|
||||
|
||||
These tests focus on CDC-FM specific functionality without importing
|
||||
the full training infrastructure that has problematic dependencies.
|
||||
These tests focus on the current CDC-FM per-file caching implementation
|
||||
with hash-based cache validation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
import numpy as np
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestCDCPreprocessor:
|
||||
"""Test CDC preprocessing functionality"""
|
||||
"""Test CDC preprocessing functionality with per-file caching"""
|
||||
|
||||
def test_cdc_preprocessor_basic_workflow(self, tmp_path):
|
||||
"""Test basic CDC preprocessing with small dataset"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)]
|
||||
)
|
||||
|
||||
# Add 10 small latents
|
||||
for i in range(10):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
# Compute and save (creates per-file CDC caches)
|
||||
files_saved = preprocessor.compute_all()
|
||||
|
||||
# Verify file was created
|
||||
assert Path(result_path).exists()
|
||||
# Verify files were created
|
||||
assert files_saved == 10
|
||||
|
||||
# Verify structure
|
||||
from safetensors import safe_open
|
||||
# Verify first CDC file structure
|
||||
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
|
||||
assert cdc_path.exists()
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
assert f.get_tensor("metadata/num_samples").item() == 10
|
||||
assert f.get_tensor("metadata/k_neighbors").item() == 5
|
||||
assert f.get_tensor("metadata/d_cdc").item() == 4
|
||||
data = np.load(cdc_path)
|
||||
assert data['k_neighbors'] == 5
|
||||
assert data['d_cdc'] == 4
|
||||
|
||||
# Check first sample
|
||||
eigvecs = f.get_tensor("eigenvectors/test_image_0")
|
||||
eigvals = f.get_tensor("eigenvalues/test_image_0")
|
||||
# Check eigenvectors and eigenvalues
|
||||
eigvecs = data['eigenvectors']
|
||||
eigvals = data['eigenvalues']
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
|
||||
def test_cdc_preprocessor_different_shapes(self, tmp_path):
|
||||
"""Test CDC preprocessing with variable-size latents (bucketing)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
|
||||
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)]
|
||||
)
|
||||
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
for i in range(5):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Add 5 latents of different shape (16, 8, 8)
|
||||
for i in range(5, 10):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
files_saved = preprocessor.compute_all()
|
||||
|
||||
# Verify both shape groups were processed
|
||||
from safetensors import safe_open
|
||||
assert files_saved == 10
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check shapes are stored
|
||||
shape_0 = f.get_tensor("shapes/test_image_0")
|
||||
shape_5 = f.get_tensor("shapes/test_image_5")
|
||||
# Check shapes are stored in individual files
|
||||
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
|
||||
)
|
||||
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
|
||||
)
|
||||
|
||||
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
||||
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
||||
data_0 = np.load(cdc_path_0)
|
||||
data_5 = np.load(cdc_path_5)
|
||||
|
||||
assert tuple(data_0['shape']) == (16, 4, 4)
|
||||
assert tuple(data_5['shape']) == (16, 8, 8)
|
||||
|
||||
|
||||
class TestGammaBDataset:
|
||||
"""Test GammaBDataset loading and retrieval"""
|
||||
"""Test GammaBDataset loading and retrieval with per-file caching"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_cdc_cache(self, tmp_path):
|
||||
"""Create a sample CDC cache file for testing"""
|
||||
cache_path = tmp_path / "test_gamma_b.safetensors"
|
||||
"""Create sample CDC cache files for testing"""
|
||||
# Use 20 samples to ensure proper k-NN computation
|
||||
# (minimum 256 neighbors recommended, but 20 samples with k=5 is sufficient for testing)
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)],
|
||||
adaptive_k=True, # Enable adaptive k for small dataset
|
||||
min_bucket_size=5
|
||||
)
|
||||
|
||||
# Create mock Γ_b data for 5 samples
|
||||
tensors = {
|
||||
"metadata/num_samples": torch.tensor([5]),
|
||||
"metadata/k_neighbors": torch.tensor([10]),
|
||||
"metadata/d_cdc": torch.tensor([4]),
|
||||
"metadata/gamma": torch.tensor([1.0]),
|
||||
}
|
||||
# Create 20 samples
|
||||
latents_npz_paths = []
|
||||
for i in range(20):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32) # C=16, d=1024 when flattened
|
||||
latents_npz_path = str(tmp_path / f"test_{i}_0008x0008_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_{i}'}
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Add shape and CDC data for each sample
|
||||
for i in range(5):
|
||||
tensors[f"shapes/{i}"] = torch.tensor([16, 8, 8]) # C, H, W
|
||||
tensors[f"eigenvectors/{i}"] = torch.randn(4, 1024, dtype=torch.float32) # d_cdc x d
|
||||
tensors[f"eigenvalues/{i}"] = torch.rand(4, dtype=torch.float32) + 0.1 # positive
|
||||
|
||||
save_file(tensors, str(cache_path))
|
||||
return cache_path
|
||||
preprocessor.compute_all()
|
||||
return tmp_path, latents_npz_paths, preprocessor.config_hash
|
||||
|
||||
def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
|
||||
"""Test that GammaBDataset loads metadata correctly"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
"""Test that GammaBDataset loads CDC files correctly"""
|
||||
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
assert gamma_b_dataset.num_samples == 5
|
||||
assert gamma_b_dataset.d_cdc == 4
|
||||
# Get components for first sample
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu")
|
||||
|
||||
# Check shapes
|
||||
assert eigvecs.shape[0] == 1 # batch size
|
||||
assert eigvecs.shape[1] == 4 # d_cdc
|
||||
assert eigvals.shape == (1, 4) # batch, d_cdc
|
||||
|
||||
def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
|
||||
"""Test retrieving Γ_b^(1/2) components"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
# Get Γ_b for indices [0, 2, 4]
|
||||
indices = [0, 2, 4]
|
||||
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(indices, device="cpu")
|
||||
# Get Γ_b for paths [0, 2, 4]
|
||||
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
|
||||
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
|
||||
# Check shapes
|
||||
assert eigenvectors.shape == (3, 4, 1024) # (batch, d_cdc, d)
|
||||
assert eigenvectors.shape[0] == 3 # batch
|
||||
assert eigenvectors.shape[1] == 4 # d_cdc
|
||||
assert eigenvalues.shape == (3, 4) # (batch, d_cdc)
|
||||
|
||||
# Check values are positive
|
||||
@@ -134,14 +178,16 @@ class TestGammaBDataset:
|
||||
|
||||
def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
|
||||
"""Test compute_sigma_t_x returns x unchanged at t=0"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
# Create test latents (batch of 3, matching d=1024 flattened)
|
||||
x = torch.randn(3, 1024) # B, d (flattened)
|
||||
t = torch.zeros(3) # t = 0 for all samples
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 1, 2], device="cpu")
|
||||
paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]]
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
@@ -150,13 +196,15 @@ class TestGammaBDataset:
|
||||
|
||||
def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
|
||||
"""Test compute_sigma_t_x returns correct shape"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
x = torch.randn(2, 1024) # B, d (flattened)
|
||||
t = torch.tensor([0.3, 0.7])
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([1, 3], device="cpu")
|
||||
paths = [latents_npz_paths[1], latents_npz_paths[3]]
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
@@ -165,13 +213,15 @@ class TestGammaBDataset:
|
||||
|
||||
def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
|
||||
"""Test compute_sigma_t_x produces finite values"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
tmp_path, latents_npz_paths, config_hash = sample_cdc_cache
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
x = torch.randn(3, 1024) # B, d (flattened)
|
||||
t = torch.rand(3) # Random timesteps in [0, 1]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 2, 4], device="cpu")
|
||||
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
@@ -187,31 +237,39 @@ class TestCDCEndToEnd:
|
||||
"""Test complete workflow: preprocess -> save -> load -> use"""
|
||||
# Step 1: Preprocess latents
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
|
||||
dataset_dirs=[str(tmp_path)]
|
||||
)
|
||||
|
||||
num_samples = 10
|
||||
latents_npz_paths = []
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
|
||||
latents_npz_paths.append(latents_npz_path)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
latents_npz_path=latents_npz_path,
|
||||
shape=latent.shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
files_saved = preprocessor.compute_all()
|
||||
assert files_saved == num_samples
|
||||
|
||||
# Step 2: Load with GammaBDataset
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
assert gamma_b_dataset.num_samples == num_samples
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
|
||||
|
||||
# Step 3: Use in mock training scenario
|
||||
batch_size = 3
|
||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||
batch_t = torch.rand(batch_size)
|
||||
image_keys = ['test_image_0', 'test_image_5', 'test_image_9']
|
||||
paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(image_keys, device="cpu")
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu")
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
"""
|
||||
Test warning throttling for CDC shape mismatches.
|
||||
|
||||
Ensures that duplicate warnings for the same sample are not logged repeatedly.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
|
||||
|
||||
|
||||
class TestWarningThrottling:
|
||||
"""Test that shape mismatch warnings are throttled"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_warned_samples(self):
|
||||
"""Clear the warned samples set before each test"""
|
||||
_cdc_warned_samples.clear()
|
||||
yield
|
||||
_cdc_warned_samples.clear()
|
||||
|
||||
@pytest.fixture
|
||||
def cdc_cache(self, tmp_path):
|
||||
"""Create a test CDC cache with one shape"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create cache with one specific shape
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=preprocessed_shape, metadata=metadata)
|
||||
|
||||
cache_path = tmp_path / "test_throttle.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
return cache_path
|
||||
|
||||
def test_warning_only_logged_once_per_sample(self, cdc_cache, caplog):
|
||||
"""
|
||||
Test that shape mismatch warning is only logged once per sample.
|
||||
|
||||
Even if the same sample appears in multiple batches, only warn once.
|
||||
"""
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
|
||||
|
||||
# Use different shape at runtime to trigger mismatch
|
||||
runtime_shape = (16, 64, 64)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
image_keys = ['test_image_0'] # Same sample
|
||||
|
||||
# First call - should warn
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise1 = torch.randn(1, *runtime_shape, dtype=torch.float32)
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise1,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have exactly one warning
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 1, "First call should produce exactly one warning"
|
||||
assert "CDC shape mismatch" in warnings[0].message
|
||||
|
||||
# Second call with same sample - should NOT warn
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise2 = torch.randn(1, *runtime_shape, dtype=torch.float32)
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise2,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Should have NO warnings
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 0, "Second call with same sample should not warn"
|
||||
|
||||
# Third call with same sample - still should NOT warn
|
||||
with caplog.at_level(logging.WARNING):
|
||||
caplog.clear()
|
||||
noise3 = torch.randn(1, *runtime_shape, dtype=torch.float32)
|
||||
_ = apply_cdc_noise_transformation(
|
||||
noise=noise3,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
image_keys=image_keys,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
warnings = [rec for rec in caplog.records if rec.levelname == "WARNING"]
|
||||
assert len(warnings) == 0, "Third call should still not warn"
|
||||
|
||||
def test_different_samples_each_get_one_warning(self, cdc_cache, caplog):
    """
    Test that different samples each get their own warning.

    Each unique sample key is warned about exactly once: repeating a batch
    with already-warned keys stays silent, while previously unseen keys
    each produce one new warning.
    """
    gb_dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
    # Runtime latent shape deliberately differs from the cached one so the
    # CDC shape-mismatch warning path fires for every new sample key.
    shape = (16, 64, 64)

    def run_batch(keys, ts):
        # Run one CDC transform over fresh noise for `keys`, capturing
        # WARNING-level log records; return how many warnings were emitted.
        with caplog.at_level(logging.WARNING):
            caplog.clear()
            batch_noise = torch.randn(len(keys), *shape, dtype=torch.float32)
            _ = apply_cdc_noise_transformation(
                noise=batch_noise,
                timesteps=ts,
                num_timesteps=1000,
                gamma_b_dataset=gb_dataset,
                image_keys=keys,
                device="cpu",
            )
        # caplog.records persists after the context manager exits.
        return sum(1 for rec in caplog.records if rec.levelname == "WARNING")

    three_ts = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
    first_keys = ['test_image_0', 'test_image_1', 'test_image_2']

    # First batch: samples 0, 1, 2 — every key is new, so one warning each.
    assert run_batch(first_keys, three_ts) == 3, "Should warn for each of the 3 samples"

    # Second batch: identical keys — all were already warned about.
    assert run_batch(first_keys, three_ts) == 0, "Should not warn again for same samples"

    # Third batch: two previously unseen keys — one warning apiece.
    two_ts = torch.tensor([100.0, 200.0], dtype=torch.float32)
    assert run_batch(['test_image_3', 'test_image_4'], two_ts) == 2, "Should warn for each of the 2 new samples"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly: verbose output, prints shown.
    pytest_args = [__file__, "-v", "-s"]
    pytest.main(pytest_args)
|
||||
Reference in New Issue
Block a user