Improve dimension mismatch warning for CDC Flow Matching

- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
This commit is contained in:
rockerBOO
2025-10-11 17:17:09 -04:00
parent aa3a216106
commit 8089cb6925
11 changed files with 1014 additions and 13 deletions

View File

@@ -0,0 +1,164 @@
"""
Tests using realistic high-dimensional data to catch scaling bugs.
This test uses realistic VAE-like latents to ensure eigenvalue normalization
works correctly on real-world data.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestRealisticDataScaling:
"""Test eigenvalue scaling with realistic high-dimensional data"""
def test_high_dimensional_latents_not_saturated(self, tmp_path):
"""
Verify that high-dimensional realistic latents don't saturate eigenvalues.
This test simulates real FLUX training data:
- High dimension (16×64×64 = 65536)
- Varied content (different variance in different regions)
- Realistic magnitude (VAE output scale)
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create 20 samples with realistic varied structure
for i in range(20):
# High-dimensional latent like FLUX
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
# Create varied structure across the latent
# Different channels have different patterns (realistic for VAE)
for c in range(16):
# Some channels have gradients
if c < 4:
for h in range(64):
for w in range(64):
latent[c, h, w] = (h + w) / 128.0
# Some channels have patterns
elif c < 8:
for h in range(64):
for w in range(64):
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
# Some channels are more uniform
else:
latent[c, :, :] = c * 0.1
# Add per-sample variation (different "subjects")
latent = latent * (1.0 + i * 0.2)
# Add realistic VAE-like noise/variation
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are NOT all saturated at 1.0
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# Critical: eigenvalues should NOT all be 1.0
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
total = len(non_zero_eigvals)
percent_at_max = (at_max / total * 100) if total > 0 else 0
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
# FAIL if too many eigenvalues are saturated at 1.0
assert percent_at_max < 80, (
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
f"This indicates the normalization bug - raw eigenvalues are not being "
f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
)
# Should have good diversity
assert np.std(non_zero_eigvals) > 0.1, (
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
f"Should see diverse eigenvalues, not all the same value."
)
# Mean should be in reasonable range (not all 1.0)
mean_eigval = np.mean(non_zero_eigvals)
assert 0.05 < mean_eigval < 0.9, (
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
f"If mean ≈ 1.0, eigenvalues are saturated."
)
def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
"""
Test that datasets with more variance produce more diverse eigenvalues.
This ensures the normalization preserves relative information.
"""
# Create two preprocessors with different data variance
results = {}
for variance_scale in [0.5, 2.0]:
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
for i in range(15):
latent = torch.zeros(16, 32, 32, dtype=torch.float32)
# Create varied patterns
for c in range(16):
for h in range(32):
for w in range(32):
latent[c, h, w] = (
np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale
)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
eigvals = []
for i in range(15):
ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
eigvals.extend(ev[ev > 1e-6])
results[variance_scale] = {
'mean': np.mean(eigvals),
'std': np.std(eigvals),
'range': (np.min(eigvals), np.max(eigvals))
}
print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
# Both should have diversity (not saturated)
for scale in [0.5, 2.0]:
assert results[scale]['std'] > 0.1, (
f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])