Files
Kohya-ss-sd-scripts/tests/library/test_cdc_eigenvalue_real_data.py
rockerBOO 8089cb6925 Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
2025-10-11 17:17:09 -04:00

165 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Tests using realistic high-dimensional data to catch scaling bugs.
This test uses realistic VAE-like latents to ensure eigenvalue normalization
works correctly on real-world data.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestRealisticDataScaling:
"""Test eigenvalue scaling with realistic high-dimensional data"""
def test_high_dimensional_latents_not_saturated(self, tmp_path):
"""
Verify that high-dimensional realistic latents don't saturate eigenvalues.
This test simulates real FLUX training data:
- High dimension (16×64×64 = 65536)
- Varied content (different variance in different regions)
- Realistic magnitude (VAE output scale)
"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create 20 samples with realistic varied structure
for i in range(20):
# High-dimensional latent like FLUX
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
# Create varied structure across the latent
# Different channels have different patterns (realistic for VAE)
for c in range(16):
# Some channels have gradients
if c < 4:
for h in range(64):
for w in range(64):
latent[c, h, w] = (h + w) / 128.0
# Some channels have patterns
elif c < 8:
for h in range(64):
for w in range(64):
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
# Some channels are more uniform
else:
latent[c, :, :] = c * 0.1
# Add per-sample variation (different "subjects")
latent = latent * (1.0 + i * 0.2)
# Add realistic VAE-like noise/variation
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are NOT all saturated at 1.0
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(20):
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# Critical: eigenvalues should NOT all be 1.0
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
total = len(non_zero_eigvals)
percent_at_max = (at_max / total * 100) if total > 0 else 0
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
# FAIL if too many eigenvalues are saturated at 1.0
assert percent_at_max < 80, (
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
f"This indicates the normalization bug - raw eigenvalues are not being "
f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
)
# Should have good diversity
assert np.std(non_zero_eigvals) > 0.1, (
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
f"Should see diverse eigenvalues, not all the same value."
)
# Mean should be in reasonable range (not all 1.0)
mean_eigval = np.mean(non_zero_eigvals)
assert 0.05 < mean_eigval < 0.9, (
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
f"If mean ≈ 1.0, eigenvalues are saturated."
)
def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
"""
Test that datasets with more variance produce more diverse eigenvalues.
This ensures the normalization preserves relative information.
"""
# Create two preprocessors with different data variance
results = {}
for variance_scale in [0.5, 2.0]:
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
for i in range(15):
latent = torch.zeros(16, 32, 32, dtype=torch.float32)
# Create varied patterns
for c in range(16):
for h in range(32):
for w in range(32):
latent[c, h, w] = (
np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale
)
metadata = {'image_key': f'test_image_{i}'}
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
preprocessor.compute_all(save_path=output_path)
with safe_open(str(output_path), framework="pt", device="cpu") as f:
eigvals = []
for i in range(15):
ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
eigvals.extend(ev[ev > 1e-6])
results[variance_scale] = {
'mean': np.mean(eigvals),
'std': np.std(eigvals),
'range': (np.min(eigvals), np.max(eigvals))
}
print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
# Both should have diversity (not saturated)
for scale in [0.5, 2.0]:
assert results[scale]['std'] > 0.1, (
f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])