mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes - Simplify test imports by removing unused modules - Minor formatting improvements in print statements - Ensure log messages provide clear context about dimension mismatches
This commit is contained in:
@@ -354,9 +354,11 @@ class LatentBatcher:
|
||||
Dict mapping exact_shape -> list of samples with that shape
|
||||
"""
|
||||
batches = {}
|
||||
shapes = set()
|
||||
|
||||
for sample in self.samples:
|
||||
shape_key = sample.shape
|
||||
shapes.add(shape_key)
|
||||
|
||||
# Group by exact shape only - no aspect ratio grouping or resizing
|
||||
if shape_key not in batches:
|
||||
@@ -364,6 +366,15 @@ class LatentBatcher:
|
||||
|
||||
batches[shape_key].append(sample)
|
||||
|
||||
# If more than one unique shape, log a warning
|
||||
if len(shapes) > 1:
|
||||
logger.warning(
|
||||
"Dimension mismatch: %d unique shapes detected. "
|
||||
"Shapes: %s. Using Gaussian fallback for these samples.",
|
||||
len(shapes),
|
||||
shapes
|
||||
)
|
||||
|
||||
return batches
|
||||
|
||||
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
|
||||
|
||||
@@ -6,8 +6,6 @@ Verifies that adaptive k properly adjusts based on bucket sizes.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
183
tests/library/test_cdc_advanced.py
Normal file
183
tests/library/test_cdc_advanced.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import torch
|
||||
from typing import Union
|
||||
|
||||
|
||||
class MockGammaBDataset:
    """Stand-in for GammaBDataset used to exercise gradient flow in tests."""

    def __init__(self, *args, **kwargs):
        """Lightweight init: no file loading, just select a compute device."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Simplified compute_sigma_t_x for testing.

        Interpolates between x and V diag(sqrt(lambda)) V^T x by factor t,
        preserving the caller's original tensor shape.
        """
        initial_shape = x.shape

        # 4D latents (B, C, H, W) are flattened to (B, C*H*W) for the einsums.
        if x.dim() == 4:
            x = x.reshape(x.shape[0], -1)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        # Sanity checks on the eigen-decomposition layout (B, k, d).
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"

        # t == 0 without grad tracking is an identity transform: short-circuit.
        if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
            return x.reshape(initial_shape)

        # V^T x
        projection = torch.einsum('bkd,bd->bk', eigenvectors, x)
        # sqrt(lambda) * V^T x  (clamped so sqrt stays finite)
        scaled_projection = torch.sqrt(eigenvalues.clamp(min=1e-10)) * projection
        # V @ (sqrt(lambda) * V^T x)
        lifted = torch.einsum('bkd,bk->bd', eigenvectors, scaled_projection)

        # Blend original and transformed latent, then restore the input shape.
        blended = (1 - t) * x + t * lifted
        return blended.reshape(initial_shape)
class TestCDCAdvanced:
    """Gradient-flow tests for the CDC Flow Matching noise transformation."""

    def setup_method(self):
        """Prepare consistent test environment"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _make_components(self, batch_size: int = 4, latent_dim: int = 64, d_cdc: int = 8):
        """
        Build a (latent, eigenvectors, eigenvalues) triple for the tests.

        Shared by both gradient-flow tests to avoid duplicated setup. The
        latent requires grad; eigenvectors are unit-normalized and eigenvalues
        are clamped away from zero so the transformation is well-conditioned.
        """
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
        eigenvectors = torch.randn(batch_size, d_cdc, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, d_cdc, device=self.device)

        # Ensure eigenvectors and eigenvalues are meaningful
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)
        return latent, eigenvectors, eigenvalues

    def test_gradient_flow_preservation(self):
        """
        Verify that gradient flow is preserved even for near-zero time steps
        with learnable time embeddings
        """
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Learnable time embedding with a small initial value
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)

        latent, eigenvectors, eigenvalues = self._make_components()

        # Compute noisy latent with gradient tracking
        mock_dataset = MockGammaBDataset()
        noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)

        # Dummy scalar loss to drive backprop
        loss = noisy_latent.sum()
        loss.backward()

        # Assertions to verify gradient flow
        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        # Check gradient magnitudes are non-zero
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()

        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"

        # Optional: Print gradient details for debugging
        print(f"Time embedding gradient magnitude: {t_grad_magnitude}")
        print(f"Latent gradient magnitude: {latent_grad_magnitude}")

    def test_gradient_flow_with_different_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]

        for time_val in time_steps:
            # Learnable time embedding for this step
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)

            latent, eigenvectors, eigenvalues = self._make_components()

            # Compute noisy latent with gradient tracking
            mock_dataset = MockGammaBDataset()
            noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)

            # Dummy scalar loss to drive backprop
            loss = noisy_latent.sum()
            loss.backward()

            # Assertions to verify gradient flow
            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()

            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"

            # Reset gradients for next iteration
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()
def pytest_configure(config):
    """Register the custom pytest marker used by the CDC-FM gradient tests."""
    marker_spec = "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
    config.addinivalue_line("markers", marker_spec)
146
tests/library/test_cdc_dimension_handling.py
Normal file
146
tests/library/test_cdc_dimension_handling.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Test CDC-FM dimension handling and fallback mechanisms.
|
||||
|
||||
This module tests the behavior of the CDC Flow Matching implementation
|
||||
when encountering latents with different dimensions.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
class _ListHandler(logging.Handler):
    """Logging handler that appends each record's message to a caller-owned list.

    Replaces the duplicated inline LogCapture classes previously defined in
    both test methods below.
    """

    def __init__(self, sink):
        super().__init__()
        self._sink = sink

    def emit(self, record):
        self._sink.append(record.getMessage())


class TestDimensionHandling:
    """Tests for CDC-FM behavior when latents have mismatched dimensions."""

    def setup_method(self):
        """Prepare consistent test environment"""
        self.logger = logging.getLogger(__name__)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_mixed_dimension_fallback(self):
        """
        Verify that preprocessor falls back to standard noise for mixed-dimension batches
        """
        # Prepare preprocessor with debug mode
        preprocessor = CDCPreprocessor(debug=True)

        # Different-sized latents (3D: channels, height, width)
        latents = [
            torch.randn(3, 32, 64),   # First latent: 3x32x64
            torch.randn(3, 32, 128),  # Second latent: 3x32x128 (different dimension)
        ]

        # Capture the module logger's output for inspection
        from library.cdc_fm import logger

        log_messages = []
        capture_handler = _ListHandler(log_messages)
        logger.addHandler(capture_handler)

        try:
            # Try adding mixed-dimension latents
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                for i, latent in enumerate(latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_mixed_image_{i}'}
                    )

                try:
                    cdc_path = preprocessor.compute_all(tmp_file.name)
                except ValueError as e:
                    # If implementation raises ValueError, that's acceptable
                    assert "Dimension mismatch" in str(e)
                    return

                # Check for dimension-related log messages
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension mismatch" in msg.lower()
                ]
                assert len(dimension_warnings) > 0, "No dimension-related warnings were logged"

                # Load results and verify fallback
                dataset = GammaBDataset(cdc_path)

        finally:
            # Always detach the capture handler
            logger.removeHandler(capture_handler)

        # Check metadata about samples with/without CDC
        assert dataset.num_samples == len(latents), "All samples should be processed"

    def test_adaptive_k_with_dimension_constraints(self):
        """
        Test adaptive k-neighbors behavior with dimension constraints
        """
        # Prepare preprocessor with adaptive k and small bucket size
        preprocessor = CDCPreprocessor(
            adaptive_k=True,
            min_bucket_size=5,
            debug=True
        )

        # Generate latents with similar but not identical dimensions
        base_latent = torch.randn(3, 32, 64)
        similar_latents = [
            base_latent,
            torch.randn(3, 32, 65),  # Slightly different dimension
            torch.randn(3, 32, 66)   # Another slightly different dimension
        ]

        # Capture the module logger's output for inspection
        from library.cdc_fm import logger

        log_messages = []
        capture_handler = _ListHandler(log_messages)
        logger.addHandler(capture_handler)

        try:
            with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
                # Add similar latents
                for i, latent in enumerate(similar_latents):
                    preprocessor.add_latent(
                        latent,
                        global_idx=i,
                        metadata={'image_key': f'test_adaptive_k_image_{i}'}
                    )

                cdc_path = preprocessor.compute_all(tmp_file.name)

                # Load results
                dataset = GammaBDataset(cdc_path)

                # Verify samples processed
                assert dataset.num_samples == len(similar_latents), "All samples should be processed"

                # Optional: Check warnings about dimension differences
                dimension_warnings = [
                    msg for msg in log_messages
                    if "dimension" in msg.lower()
                ]
                print(f"Dimension-related warnings: {dimension_warnings}")

        finally:
            # Always detach the capture handler
            logger.removeHandler(capture_handler)
def pytest_configure(config):
    """Register the custom pytest marker for dimension-handling tests."""
    marker_spec = "dimension_handling: mark test for CDC-FM dimension mismatch scenarios"
    config.addinivalue_line("markers", marker_spec)
164
tests/library/test_cdc_eigenvalue_real_data.py
Normal file
164
tests/library/test_cdc_eigenvalue_real_data.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Tests using realistic high-dimensional data to catch scaling bugs.
|
||||
|
||||
This test uses realistic VAE-like latents to ensure eigenvalue normalization
|
||||
works correctly on real-world data.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestRealisticDataScaling:
    """Test eigenvalue scaling with realistic high-dimensional data"""

    def test_high_dimensional_latents_not_saturated(self, tmp_path):
        """
        Verify that high-dimensional realistic latents don't saturate eigenvalues.

        This test simulates real FLUX training data:
        - High dimension (16×64×64 = 65536)
        - Varied content (different variance in different regions)
        - Realistic magnitude (VAE output scale)
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )

        # Precompute the spatial planes once, vectorized. The previous
        # per-element Python loops ran 16*64*64 ≈ 65k iterations per sample
        # (×20 samples); this builds the same values in a handful of array ops.
        coords = np.arange(64, dtype=np.float64)
        # Channels 0..3: gradient plane (h + w) / 128
        gradient_plane = torch.from_numpy(
            (coords[:, None] + coords[None, :]) / 128.0
        ).to(torch.float32)
        # Channels 4..7: wave plane sin(h/10) * cos(w/10)
        wave_plane = torch.from_numpy(
            np.sin(coords / 10.0)[:, None] * np.cos(coords / 10.0)[None, :]
        ).to(torch.float32)
        # Channels 8..15: uniform per-channel level c * 0.1
        uniform_levels = (
            torch.arange(8, 16, dtype=torch.float64) * 0.1
        ).to(torch.float32).view(8, 1, 1)

        # Create 20 samples with realistic varied structure
        for i in range(20):
            # High-dimensional latent like FLUX, with varied per-channel structure
            latent = torch.zeros(16, 64, 64, dtype=torch.float32)
            latent[:4] = gradient_plane
            latent[4:8] = wave_plane
            latent[8:] = uniform_levels

            # Add per-sample variation (different "subjects")
            latent = latent * (1.0 + i * 0.2)

            # Add realistic VAE-like noise/variation
            latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)

            metadata = {'image_key': f'test_image_{i}'}

            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)

        output_path = tmp_path / "test_realistic_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)

        # Verify eigenvalues are NOT all saturated at 1.0
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            all_eigvals = []
            for i in range(20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                all_eigvals.extend(eigvals)

        all_eigvals = np.array(all_eigvals)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]

        # Critical: eigenvalues should NOT all be 1.0
        at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
        total = len(non_zero_eigvals)
        percent_at_max = (at_max / total * 100) if total > 0 else 0

        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
        print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")

        # FAIL if too many eigenvalues are saturated at 1.0
        assert percent_at_max < 80, (
            f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
            f"This indicates the normalization bug - raw eigenvalues are not being "
            f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
        )

        # Should have good diversity
        assert np.std(non_zero_eigvals) > 0.1, (
            f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
            f"Should see diverse eigenvalues, not all the same value."
        )

        # Mean should be in reasonable range (not all 1.0)
        mean_eigval = np.mean(non_zero_eigvals)
        assert 0.05 < mean_eigval < 0.9, (
            f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
            f"If mean ≈ 1.0, eigenvalues are saturated."
        )

    def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
        """
        Test that datasets with more variance produce more diverse eigenvalues.

        This ensures the normalization preserves relative information.
        """
        # Create two preprocessors with different data variance
        results = {}
        coords = np.arange(32, dtype=np.float64)

        for variance_scale in [0.5, 2.0]:
            preprocessor = CDCPreprocessor(
                k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
            )

            for i in range(15):
                # Vectorized pattern sin(h/5 + i) * cos(w/5 + c) * scale
                # (replaces the 16*32*32-iteration Python loop per sample).
                rows = np.sin(coords / 5.0 + i)[:, None]
                planes = np.stack([
                    rows * np.cos(coords / 5.0 + c)[None, :] * variance_scale
                    for c in range(16)
                ])
                latent = torch.from_numpy(planes).to(torch.float32)

                metadata = {'image_key': f'test_image_{i}'}

                preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)

            output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
            preprocessor.compute_all(save_path=output_path)

            with safe_open(str(output_path), framework="pt", device="cpu") as f:
                eigvals = []
                for i in range(15):
                    ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                    eigvals.extend(ev[ev > 1e-6])

            results[variance_scale] = {
                'mean': np.mean(eigvals),
                'std': np.std(eigvals),
                'range': (np.min(eigvals), np.max(eigvals))
            }

        print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
        print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")

        # Both should have diversity (not saturated)
        for scale in [0.5, 2.0]:
            assert results[scale]['std'] > 0.1, (
                f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
            )
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -6,8 +6,6 @@ Ensures that gradients propagate correctly through both fast and slow paths.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation
|
||||
|
||||
@@ -4,7 +4,6 @@ Test comparing interpolation vs pad/truncate for CDC preprocessing.
|
||||
This test quantifies the difference between the two approaches.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -89,16 +88,16 @@ class TestInterpolationComparison:
|
||||
print("\n" + "=" * 60)
|
||||
print("Reconstruction Error Comparison")
|
||||
print("=" * 60)
|
||||
print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
|
||||
print("\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
|
||||
print(f" Interpolation error: {interp_error_small:.6f}")
|
||||
print(f" Pad/truncate error: {pad_error_small:.6f}")
|
||||
if pad_error_small > 0:
|
||||
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
|
||||
else:
|
||||
print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
|
||||
print(f" BUT the intermediate representation is corrupted with zeros!")
|
||||
print(" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
|
||||
print(" BUT the intermediate representation is corrupted with zeros!")
|
||||
|
||||
print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
|
||||
print("\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
|
||||
print(f" Interpolation error: {interp_error_large:.6f}")
|
||||
print(f" Pad/truncate error: {truncate_error_large:.6f}")
|
||||
if truncate_error_large > 0:
|
||||
@@ -151,7 +150,7 @@ class TestInterpolationComparison:
|
||||
print("\n" + "=" * 60)
|
||||
print("Spatial Structure Preservation")
|
||||
print("=" * 60)
|
||||
print(f"\nGradient smoothness (lower is smoother):")
|
||||
print("\nGradient smoothness (lower is smoother):")
|
||||
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
|
||||
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
|
||||
|
||||
|
||||
268
tests/library/test_cdc_performance.py
Normal file
268
tests/library/test_cdc_performance.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
Performance benchmarking for CDC Flow Matching implementation.
|
||||
|
||||
This module tests the computational overhead and noise injection properties
|
||||
of the CDC-FM preprocessing pipeline.
|
||||
"""
|
||||
|
||||
import time
|
||||
import tempfile
|
||||
import torch
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
class TestCDCPerformance:
    """
    Performance and Noise Injection Verification Tests for CDC Flow Matching

    These tests validate the computational performance and noise injection properties
    of the CDC-FM preprocessing pipeline across different latent sizes.

    Key Verification Points:
    1. Computational efficiency for various latent dimensions
    2. Noise injection statistical properties
    3. Eigenvector and eigenvalue characteristics
    """

    # Geometric embedding dimensionality shared by the preprocessor config and
    # downstream statistics. Previously a magic "8" was hard-coded in the mean
    # computation of test_noise_distribution, silently duplicating this value.
    D_CDC = 8

    @pytest.fixture(params=[
        (3, 32, 32),    # Small latent: typical for compact representations
        (3, 64, 64),    # Medium latent: standard feature maps
        (3, 128, 128)   # Large latent: high-resolution feature spaces
    ])
    def latent_sizes(self, request):
        """
        Parametrized fixture generating test cases for different latent sizes.

        Rationale:
        - Tests robustness across various computational scales
        - Ensures consistent behavior from compact to large representations
        - Identifies potential dimensionality-related performance bottlenecks
        """
        return request.param

    def test_computational_overhead(self, latent_sizes):
        """
        Measure computational overhead of CDC preprocessing across latent sizes.

        Performance Verification Objectives:
        1. Verify preprocessing time scales predictably with input dimensions
        2. Ensure adaptive k-neighbors works efficiently
        3. Validate computational overhead remains within acceptable bounds

        Performance Metrics:
        - Total preprocessing time
        - Per-sample processing time
        - Computational complexity indicators

        Args:
            latent_sizes (tuple): Latent dimensions (C, H, W) to benchmark
        """
        # Tuned preprocessing configuration
        preprocessor = CDCPreprocessor(
            k_neighbors=256,      # Comprehensive neighborhood exploration
            d_cdc=self.D_CDC,     # Geometric embedding dimensionality
            debug=True,           # Enable detailed performance logging
            adaptive_k=True       # Dynamic neighborhood size adjustment
        )

        # Fixed seed for reproducible latent generation
        torch.manual_seed(42)

        # Generate representative latent batch
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)

        # Precision timing of preprocessing
        start_time = time.perf_counter()

        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with traceable metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'perf_test_image_{i}'}
                )

            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)

        # Calculate precise preprocessing metrics
        end_time = time.perf_counter()
        preprocessing_time = end_time - start_time
        per_sample_time = preprocessing_time / batch_size

        # Performance reporting and assertions
        input_volume = np.prod(latent_sizes)
        time_complexity_indicator = preprocessing_time / input_volume

        print(f"\nPerformance Breakdown:")
        print(f"  Latent Size: {latent_sizes}")
        print(f"  Total Samples: {batch_size}")
        print(f"  Input Volume: {input_volume}")
        print(f"  Total Time: {preprocessing_time:.4f} seconds")
        print(f"  Per Sample Time: {per_sample_time:.6f} seconds")
        print(f"  Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel")

        # Adaptive thresholds based on input dimensions
        max_total_time = 10.0       # Base threshold
        max_per_sample_time = 2.0   # Per-sample time threshold (more lenient)

        # Different time complexity thresholds for different latent sizes
        # (reuse input_volume instead of recomputing np.prod)
        max_time_complexity = (
            1e-2 if input_volume <= 3072 else  # Smaller latents
            1e-4                               # Standard latents
        )

        # Performance assertions with informative error messages
        assert preprocessing_time < max_total_time, (
            f"Total preprocessing time exceeded threshold!\n"
            f"  Latent Size: {latent_sizes}\n"
            f"  Total Time: {preprocessing_time:.4f} seconds\n"
            f"  Threshold: {max_total_time} seconds"
        )

        assert per_sample_time < max_per_sample_time, (
            f"Per-sample processing time exceeded threshold!\n"
            f"  Latent Size: {latent_sizes}\n"
            f"  Per Sample Time: {per_sample_time:.6f} seconds\n"
            f"  Threshold: {max_per_sample_time} seconds"
        )

        # More adaptable time complexity check
        assert time_complexity_indicator < max_time_complexity, (
            f"Time complexity scaling exceeded expectations!\n"
            f"  Latent Size: {latent_sizes}\n"
            f"  Input Volume: {input_volume}\n"
            f"  Time/Volume Ratio: {time_complexity_indicator:.8f} seconds/voxel\n"
            f"  Threshold: {max_time_complexity} seconds/voxel"
        )

    def test_noise_distribution(self, latent_sizes):
        """
        Verify CDC noise injection quality and properties.

        Based on test plan objectives:
        1. CDC noise is actually being generated (not all Gaussian fallback)
        2. Eigenvalues are valid (non-negative, bounded)
        3. CDC components are finite and usable for noise generation

        Args:
            latent_sizes (tuple): Latent dimensions (C, H, W)
        """
        # Preprocessing configuration
        preprocessor = CDCPreprocessor(
            k_neighbors=16,       # Reduced to match batch size
            d_cdc=self.D_CDC,
            gamma=1.0,
            debug=True,
            adaptive_k=True
        )

        # Fixed seed for reproducible latent generation
        torch.manual_seed(42)

        # Generate batch of latents
        batch_size = 32
        latents = torch.randn(batch_size, *latent_sizes)

        with tempfile.NamedTemporaryFile(suffix='.safetensors') as tmp_file:
            # Add latents with metadata
            for i, latent in enumerate(latents):
                preprocessor.add_latent(
                    latent,
                    global_idx=i,
                    metadata={'image_key': f'noise_dist_image_{i}'}
                )

            # Compute CDC results
            cdc_path = preprocessor.compute_all(tmp_file.name)

            # Analyze noise properties (inside the with-block: the temp file
            # must still exist while the dataset reads it)
            dataset = GammaBDataset(cdc_path)

            # Track samples that used CDC vs Gaussian fallback
            cdc_samples = 0
            gaussian_samples = 0
            eigenvalue_stats = {
                'min': float('inf'),
                'max': float('-inf'),
                'mean': 0.0,
                'sum': 0.0
            }

            # Verify each sample's CDC components
            for i in range(batch_size):
                image_key = f'noise_dist_image_{i}'

                # Get eigenvectors and eigenvalues
                eigvecs, eigvals = dataset.get_gamma_b_sqrt([image_key])

                # Skip zero eigenvectors (fallback case)
                if torch.all(eigvecs[0] == 0):
                    gaussian_samples += 1
                    continue

                # Get the top d_cdc eigenvectors and eigenvalues
                top_eigvecs = eigvecs[0]  # (d_cdc, d)
                top_eigvals = eigvals[0]  # (d_cdc,)

                # Basic validity checks
                assert torch.all(torch.isfinite(top_eigvecs)), f"Non-finite eigenvectors for sample {i}"
                assert torch.all(torch.isfinite(top_eigvals)), f"Non-finite eigenvalues for sample {i}"

                # Eigenvalue bounds (should be positive and <= 1.0 based on CDC-FM)
                assert torch.all(top_eigvals >= 0), f"Negative eigenvalues for sample {i}: {top_eigvals}"
                assert torch.all(top_eigvals <= 1.0), f"Eigenvalues exceed 1.0 for sample {i}: {top_eigvals}"

                # Update statistics
                eigenvalue_stats['min'] = min(eigenvalue_stats['min'], top_eigvals.min().item())
                eigenvalue_stats['max'] = max(eigenvalue_stats['max'], top_eigvals.max().item())
                eigenvalue_stats['sum'] += top_eigvals.sum().item()

                cdc_samples += 1

            # Compute mean eigenvalue across all CDC samples
            if cdc_samples > 0:
                eigenvalue_stats['mean'] = eigenvalue_stats['sum'] / (cdc_samples * self.D_CDC)

            # Print final statistics
            print(f"\nNoise Distribution Results for latent size {latent_sizes}:")
            print(f"  CDC samples: {cdc_samples}/{batch_size}")
            print(f"  Gaussian fallback: {gaussian_samples}/{batch_size}")
            print(f"  Eigenvalue min: {eigenvalue_stats['min']:.4f}")
            print(f"  Eigenvalue max: {eigenvalue_stats['max']:.4f}")
            print(f"  Eigenvalue mean: {eigenvalue_stats['mean']:.4f}")

            # Assertions based on plan objectives

            # 1. CDC noise should be generated for most samples
            assert cdc_samples > 0, "No samples used CDC noise injection"
            assert gaussian_samples < batch_size // 2, (
                f"Too many samples fell back to Gaussian noise: {gaussian_samples}/{batch_size}"
            )

            # 2. Eigenvalues should be valid (non-negative and bounded)
            assert eigenvalue_stats['min'] >= 0, "Eigenvalues should be non-negative"
            assert eigenvalue_stats['max'] <= 1.0, "Maximum eigenvalue exceeds 1.0"

            # 3. Mean eigenvalue should be reasonable (not degenerate)
            assert eigenvalue_stats['mean'] > 0.05, (
                f"Mean eigenvalue too low ({eigenvalue_stats['mean']:.4f}), "
                "suggests degenerate CDC components"
            )
def pytest_configure(config):
    """Register the custom CDC-FM pytest markers.

    Declares the ``performance`` and ``noise_distribution`` markers via
    ``config.addinivalue_line`` so pytest does not emit unknown-mark warnings
    for tests decorated with them.
    """
    marker_definitions = (
        "performance: mark test to verify CDC-FM computational performance",
        "noise_distribution: mark test to verify noise injection properties",
    )
    for definition in marker_definitions:
        config.addinivalue_line("markers", definition)
|
||||
tests/library/test_cdc_rescaling_recommendations.py — new file, 237 lines
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Tests to validate the CDC rescaling recommendations from paper review.
|
||||
|
||||
These tests check:
|
||||
1. Gamma parameter interaction with rescaling
|
||||
2. Spatial adaptivity of eigenvalue scaling
|
||||
3. Verification of fixed vs adaptive rescaling behavior
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestGammaRescalingInteraction:
    """Test that gamma parameter works correctly with eigenvalue rescaling"""

    @staticmethod
    def _build_latent(value_fn):
        # Deterministic (16, 4, 4) float32 latent whose elements come from
        # value_fn(c, h, w); keeps every run of these tests reproducible.
        flat = [
            value_fn(c, h, w)
            for c in range(16)
            for h in range(4)
            for w in range(4)
        ]
        return torch.tensor(flat, dtype=torch.float32).reshape(16, 4, 4)

    def test_gamma_scales_eigenvalues_correctly(self, tmp_path):
        """Verify gamma multiplier is applied correctly after rescaling"""
        eigenvalue_results = {}

        # Run the identical deterministic dataset through three gamma values
        # so the only difference between runs is the gamma multiplier.
        for gamma in (0.5, 1.0, 2.0):
            preprocessor = CDCPreprocessor(
                k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=gamma, device="cpu"
            )

            for idx in range(10):
                latent = self._build_latent(
                    lambda c, h, w, i=idx: (c + h * 4 + w) / 32.0 + i * 0.1
                )
                preprocessor.add_latent(
                    latent=latent,
                    global_idx=idx,
                    shape=latent.shape,
                    metadata={'image_key': f'test_image_{idx}'},
                )

            output_path = tmp_path / f"test_gamma_{gamma}.safetensors"
            preprocessor.compute_all(save_path=output_path)

            # Pull the eigenvalues of the first sample for comparison.
            with safe_open(str(output_path), framework="pt", device="cpu") as f:
                eigenvalue_results[gamma] = f.get_tensor("eigenvalues/test_image_0").numpy()

        # Eigenvalues are clamped to [1e-3, gamma * 1.0], so gamma sets the
        # upper bound: gamma 0.5 -> ~0.5, gamma 1.0 -> ~1.0, gamma 2.0 -> ~2.0.
        max_0p5 = np.max(eigenvalue_results[0.5])
        max_1p0 = np.max(eigenvalue_results[1.0])
        max_2p0 = np.max(eigenvalue_results[2.0])

        assert max_0p5 <= 0.5 + 0.01, f"Gamma 0.5 max should be ≤0.5, got {max_0p5}"
        assert max_1p0 <= 1.0 + 0.01, f"Gamma 1.0 max should be ≤1.0, got {max_1p0}"
        assert max_2p0 <= 2.0 + 0.01, f"Gamma 2.0 max should be ≤2.0, got {max_2p0}"

        # The clamp lower bound (1e-3) is gamma-independent.
        for gamma in (0.5, 1.0, 2.0):
            eigvals = eigenvalue_results[gamma]
            assert np.min(eigvals[eigvals > 0]) >= 1e-3

        print(f"\n✓ Gamma 0.5 max: {max_0p5:.4f}")
        print(f"✓ Gamma 1.0 max: {max_1p0:.4f}")
        print(f"✓ Gamma 2.0 max: {max_2p0:.4f}")

    def test_large_gamma_maintains_reasonable_scale(self, tmp_path):
        """Verify that large gamma values don't cause eigenvalue explosion"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=10.0, device="cpu"
        )

        for idx in range(10):
            latent = self._build_latent(
                lambda c, h, w, i=idx: (c + h + w) / 20.0 + i * 0.15
            )
            preprocessor.add_latent(
                latent=latent,
                global_idx=idx,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        output_path = tmp_path / "test_large_gamma.safetensors"
        preprocessor.compute_all(save_path=output_path)

        all_eigvals = []
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            for i in range(10):
                all_eigvals.extend(f.get_tensor(f"eigenvalues/test_image_{i}").numpy())

        max_eigval = np.max(all_eigvals)
        mean_eigval = np.mean([e for e in all_eigvals if e > 1e-6])

        # With gamma=10.0 and target_scale=0.1 the eigenvalues may grow,
        # but they must remain bounded (no explosion).
        assert max_eigval < 100, f"Max eigenvalue {max_eigval} too large even with large gamma"
        assert mean_eigval <= 10, f"Mean eigenvalue {mean_eigval} too large even with large gamma"

        print(f"\n✓ With gamma=10.0: max={max_eigval:.2f}, mean={mean_eigval:.2f}")
|
||||
|
||||
|
||||
class TestSpatialAdaptivityOfRescaling:
    """Test spatial variation in eigenvalue scaling"""

    def test_eigenvalues_vary_spatially(self, tmp_path):
        """Verify eigenvalues differ across spatially separated clusters"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )

        # Cluster 1: tight cluster (low variance) — deterministic small
        # spread around the origin.
        for idx in range(10):
            values = [
                (c + h + w) / 100.0 + idx * 0.01
                for c in range(16)
                for h in range(4)
                for w in range(4)
            ]
            latent = torch.tensor(values, dtype=torch.float32).reshape(16, 4, 4)
            preprocessor.add_latent(
                latent=latent,
                global_idx=idx,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        # Cluster 2: loose cluster (high variance) — deterministic large
        # spread around 5.0.
        for idx in range(10, 20):
            offsets = torch.tensor(
                [
                    (c + h + w) / 10.0 + (idx - 10) * 0.2
                    for c in range(16)
                    for h in range(4)
                    for w in range(4)
                ],
                dtype=torch.float32,
            ).reshape(16, 4, 4)
            latent = torch.ones(16, 4, 4) * 5.0 + offsets
            preprocessor.add_latent(
                latent=latent,
                global_idx=idx,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        output_path = tmp_path / "test_spatial_variation.safetensors"
        preprocessor.compute_all(save_path=output_path)

        # Collect the per-sample maximum eigenvalue for each cluster.
        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            cluster1_eigvals = [
                np.max(f.get_tensor(f"eigenvalues/test_image_{i}").numpy())
                for i in range(10)
            ]
            cluster2_eigvals = [
                np.max(f.get_tensor(f"eigenvalues/test_image_{i}").numpy())
                for i in range(10, 20)
            ]

        cluster1_mean = np.mean(cluster1_eigvals)
        cluster2_mean = np.mean(cluster2_eigvals)

        print(f"\n✓ Tight cluster max eigenvalue: {cluster1_mean:.4f}")
        print(f"✓ Loose cluster max eigenvalue: {cluster2_mean:.4f}")

        # With fixed target_scale rescaling, both clusters end up at a similar
        # magnitude despite their different local geometry — this demonstrates
        # the limitation of fixed (non-adaptive) rescaling.
        ratio = cluster2_mean / (cluster1_mean + 1e-10)
        print(f"✓ Ratio (loose/tight): {ratio:.2f}")

        # Both should be rescaled to a similar magnitude (~0.1 target_scale).
        assert 0.01 < cluster1_mean < 10.0, "Cluster 1 eigenvalues out of expected range"
        assert 0.01 < cluster2_mean < 10.0, "Cluster 2 eigenvalues out of expected range"
|
||||
|
||||
|
||||
class TestFixedVsAdaptiveRescaling:
    """Compare current fixed rescaling vs paper's adaptive approach"""

    def test_current_rescaling_is_uniform(self, tmp_path):
        """Demonstrate that current rescaling produces uniform eigenvalue scales"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )

        # Deterministic samples with varying local density: the first ten form
        # a dense cluster near the origin (small per-sample step), the last
        # ten are isolated points (large per-sample step).
        for idx in range(20):
            step = 0.05 if idx < 10 else 2.0
            values = [
                (c + h + w) / 40.0 + idx * step
                for c in range(16)
                for h in range(4)
                for w in range(4)
            ]
            latent = torch.tensor(values, dtype=torch.float32).reshape(16, 4, 4)
            preprocessor.add_latent(
                latent=latent,
                global_idx=idx,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

        output_path = tmp_path / "test_uniform_rescaling.safetensors"
        preprocessor.compute_all(save_path=output_path)

        with safe_open(str(output_path), framework="pt", device="cpu") as f:
            max_eigenvalues = []
            for i in range(20):
                eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                valid = eigvals[eigvals > 1e-6]
                if valid.size:  # at least one valid eigen-value
                    max_eigenvalues.append(valid.max())

        if not max_eigenvalues:  # safeguard against empty list
            pytest.skip("no valid eigen-values found")

        max_eigenvalues = np.array(max_eigenvalues)

        # Coefficient of variation (std / mean) measures how uniform the
        # per-sample eigenvalue scales are after rescaling.
        cv = max_eigenvalues.std() / max_eigenvalues.mean()

        print(f"\n✓ Max eigenvalues range: [{np.min(max_eigenvalues):.4f}, {np.max(max_eigenvalues):.4f}]")
        print(f"✓ Mean: {np.mean(max_eigenvalues):.4f}, Std: {np.std(max_eigenvalues):.4f}")
        print(f"✓ Coefficient of variation: {cv:.4f}")

        # With clamping, eigenvalues should show relatively low variation and
        # the mean should stay in the clamp range [1e-3, gamma * 1.0].
        assert cv < 1.0, "Eigenvalues should have relatively low variation with clamping"
        assert 0.01 < np.mean(max_eigenvalues) <= 1.0, f"Mean eigenvalue {np.mean(max_eigenvalues)} out of expected range"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly: verbose output (-v) with
    # stdout capture disabled (-s) so the diagnostic prints are visible.
    pytest.main([__file__, "-v", "-s"])
|
||||
@@ -5,10 +5,8 @@ These tests focus on CDC-FM specific functionality without importing
|
||||
the full training infrastructure that has problematic dependencies.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
@@ -7,7 +7,6 @@ Ensures that duplicate warnings for the same sample are not logged repeatedly.
|
||||
import pytest
|
||||
import torch
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation, _cdc_warned_samples
|
||||
|
||||
Reference in New Issue
Block a user