mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 17:35:47 +00:00
Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes - Simplify test imports by removing unused modules - Minor formatting improvements in print statements - Ensure log messages provide clear context about dimension mismatches
This commit is contained in:
164
tests/library/test_cdc_eigenvalue_real_data.py
Normal file
164
tests/library/test_cdc_eigenvalue_real_data.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Tests using realistic high-dimensional data to catch scaling bugs.
|
||||
|
||||
This test uses realistic VAE-like latents to ensure eigenvalue normalization
|
||||
works correctly on real-world data.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestRealisticDataScaling:
|
||||
"""Test eigenvalue scaling with realistic high-dimensional data"""
|
||||
|
||||
def test_high_dimensional_latents_not_saturated(self, tmp_path):
|
||||
"""
|
||||
Verify that high-dimensional realistic latents don't saturate eigenvalues.
|
||||
|
||||
This test simulates real FLUX training data:
|
||||
- High dimension (16×64×64 = 65536)
|
||||
- Varied content (different variance in different regions)
|
||||
- Realistic magnitude (VAE output scale)
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create 20 samples with realistic varied structure
|
||||
for i in range(20):
|
||||
# High-dimensional latent like FLUX
|
||||
latent = torch.zeros(16, 64, 64, dtype=torch.float32)
|
||||
|
||||
# Create varied structure across the latent
|
||||
# Different channels have different patterns (realistic for VAE)
|
||||
for c in range(16):
|
||||
# Some channels have gradients
|
||||
if c < 4:
|
||||
for h in range(64):
|
||||
for w in range(64):
|
||||
latent[c, h, w] = (h + w) / 128.0
|
||||
# Some channels have patterns
|
||||
elif c < 8:
|
||||
for h in range(64):
|
||||
for w in range(64):
|
||||
latent[c, h, w] = np.sin(h / 10.0) * np.cos(w / 10.0)
|
||||
# Some channels are more uniform
|
||||
else:
|
||||
latent[c, :, :] = c * 0.1
|
||||
|
||||
# Add per-sample variation (different "subjects")
|
||||
latent = latent * (1.0 + i * 0.2)
|
||||
|
||||
# Add realistic VAE-like noise/variation
|
||||
latent = latent + torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64) * (i % 3)
|
||||
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / "test_realistic_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify eigenvalues are NOT all saturated at 1.0
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(20):
|
||||
eigvals = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# Critical: eigenvalues should NOT all be 1.0
|
||||
at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
|
||||
total = len(non_zero_eigvals)
|
||||
percent_at_max = (at_max / total * 100) if total > 0 else 0
|
||||
|
||||
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
|
||||
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
|
||||
print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
|
||||
print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")
|
||||
|
||||
# FAIL if too many eigenvalues are saturated at 1.0
|
||||
assert percent_at_max < 80, (
|
||||
f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
|
||||
f"This indicates the normalization bug - raw eigenvalues are not being "
|
||||
f"scaled before clamping. Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
|
||||
)
|
||||
|
||||
# Should have good diversity
|
||||
assert np.std(non_zero_eigvals) > 0.1, (
|
||||
f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
|
||||
f"Should see diverse eigenvalues, not all the same value."
|
||||
)
|
||||
|
||||
# Mean should be in reasonable range (not all 1.0)
|
||||
mean_eigval = np.mean(non_zero_eigvals)
|
||||
assert 0.05 < mean_eigval < 0.9, (
|
||||
f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
|
||||
f"If mean ≈ 1.0, eigenvalues are saturated."
|
||||
)
|
||||
|
||||
def test_eigenvalue_diversity_scales_with_data_variance(self, tmp_path):
|
||||
"""
|
||||
Test that datasets with more variance produce more diverse eigenvalues.
|
||||
|
||||
This ensures the normalization preserves relative information.
|
||||
"""
|
||||
# Create two preprocessors with different data variance
|
||||
results = {}
|
||||
|
||||
for variance_scale in [0.5, 2.0]:
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(15):
|
||||
latent = torch.zeros(16, 32, 32, dtype=torch.float32)
|
||||
|
||||
# Create varied patterns
|
||||
for c in range(16):
|
||||
for h in range(32):
|
||||
for w in range(32):
|
||||
latent[c, h, w] = (
|
||||
np.sin(h / 5.0 + i) * np.cos(w / 5.0 + c) * variance_scale
|
||||
)
|
||||
|
||||
metadata = {'image_key': f'test_image_{i}'}
|
||||
|
||||
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)
|
||||
|
||||
output_path = tmp_path / f"test_variance_{variance_scale}.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(output_path), framework="pt", device="cpu") as f:
|
||||
eigvals = []
|
||||
for i in range(15):
|
||||
ev = f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
|
||||
eigvals.extend(ev[ev > 1e-6])
|
||||
|
||||
results[variance_scale] = {
|
||||
'mean': np.mean(eigvals),
|
||||
'std': np.std(eigvals),
|
||||
'range': (np.min(eigvals), np.max(eigvals))
|
||||
}
|
||||
|
||||
print(f"\n✓ Low variance data: mean={results[0.5]['mean']:.4f}, std={results[0.5]['std']:.4f}")
|
||||
print(f"✓ High variance data: mean={results[2.0]['mean']:.4f}, std={results[2.0]['std']:.4f}")
|
||||
|
||||
# Both should have diversity (not saturated)
|
||||
for scale in [0.5, 2.0]:
|
||||
assert results[scale]['std'] > 0.1, (
|
||||
f"Variance scale {scale} has too low std: {results[scale]['std']:.4f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
Reference in New Issue
Block a user