mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
- Merged redundant test files - Removed 'comprehensive' from file and docstring names - Improved test organization and clarity - Ensured all tests continue to pass - Simplified test documentation
220 lines
8.8 KiB
Python
220 lines
8.8 KiB
Python
"""
Comprehensive CDC Eigenvalue Validation Tests

These tests ensure that eigenvalue computation and scaling work correctly
across various scenarios, including:
- Scaling to reasonable ranges
- Handling high-dimensional data
- Preserving latent information
- Preventing computational artifacts
"""
|
||
|
||
import numpy as np
|
||
import pytest
|
||
import torch
|
||
from safetensors import safe_open
|
||
|
||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||
|
||
|
||
class TestEigenvalueScaling:
    """Verify eigenvalue scaling and computational properties."""

    @staticmethod
    def _load_eigenvalues(result_path, num_samples):
        """Return the saved eigenvalues of the first ``num_samples`` test images.

        Reads the ``eigenvalues/test_image_{i}`` tensors for
        ``i in range(num_samples)`` from the safetensors file at
        ``result_path`` and concatenates them, in index order, into a single
        NumPy array.
        """
        with safe_open(str(result_path), framework="pt", device="cpu") as f:
            per_sample = [
                f.get_tensor(f"eigenvalues/test_image_{i}").numpy()
                for i in range(num_samples)
            ]
            return np.concatenate(per_sample)

    def test_eigenvalues_in_correct_range(self, tmp_path):
        """
        Verify eigenvalues are scaled to ~0.01-1.0 range, not millions.

        Ensures:
        - No numerical explosions
        - Reasonable eigenvalue magnitudes
        - Consistent scaling across samples
        """
        num_samples = 10
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )

        # Deterministic structured pattern shared by all 16 channels:
        # value at (h, w) is (h * 8 + w) / 32.0, i.e. 0 .. 63/32.
        spatial_ramp = torch.arange(64, dtype=torch.float32).view(8, 8) / 32.0
        base = spatial_ramp.expand(16, 8, 8)

        for i in range(num_samples):
            # Per-sample offset; the addition materialises a fresh tensor.
            latent = base + i * 0.1
            metadata = {'image_key': f'test_image_{i}'}

            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)

        output_path = tmp_path / "test_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)

        # Verify eigenvalues are in correct range
        all_eigvals = self._load_eigenvalues(result_path, num_samples)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]

        # Critical assertions for eigenvalue scale
        assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
        assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
        assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"

        # Check sqrt (used in noise) is reasonable
        sqrt_max = np.sqrt(all_eigvals.max())
        assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"

        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
        print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ sqrt(max): {sqrt_max:.4f}")

    def test_high_dimensional_latents_scaling(self, tmp_path):
        """
        Verify scaling for high-dimensional realistic latents.

        Key scenarios:
        - High-dimensional data (16×64×64)
        - Varied channel structures
        - Realistic VAE-like data
        """
        num_samples = 20
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )

        # The structured pattern does not depend on the sample index, so it is
        # built once. Patterns are evaluated in float64 and cast to float32 on
        # assignment, matching element-wise assignment of Python floats.
        rows = np.arange(64)
        diagonal_ramp = (rows[:, None] + rows[None, :]) / 128.0  # (h + w) / 128
        wave = np.outer(np.sin(rows / 10.0), np.cos(rows / 10.0))  # sin(h/10) * cos(w/10)

        # High-dimensional latent like FLUX
        base = torch.zeros(16, 64, 64, dtype=torch.float32)
        base[:4] = torch.from_numpy(diagonal_ramp)   # channels 0-3
        base[4:8] = torch.from_numpy(wave)           # channels 4-7
        for c in range(8, 16):                       # channels 8-15: constant planes
            base[c] = c * 0.1

        channel_offsets = torch.linspace(-0.5, 0.5, 16).view(16, 1, 1).expand(16, 64, 64)

        # Create 20 samples with realistic varied structure
        for i in range(num_samples):
            # Add per-sample variation
            latent = base * (1.0 + i * 0.2)
            latent = latent + channel_offsets * (i % 3)

            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)

        output_path = tmp_path / "test_realistic_gamma_b.safetensors"
        result_path = preprocessor.compute_all(save_path=output_path)

        # Verify eigenvalues are not all saturated
        all_eigvals = self._load_eigenvalues(result_path, num_samples)
        non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]

        # Fail early with a clear message instead of letting mean/std of an
        # empty array produce NaN and a misleading "std too low" failure.
        assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"

        at_max = np.sum(np.abs(all_eigvals - 1.0) < 0.01)
        total = len(non_zero_eigvals)
        percent_at_max = at_max / total * 100

        print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
        print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
        print(f"✓ Std: {np.std(non_zero_eigvals):.4f}")
        print(f"✓ At max (1.0): {at_max}/{total} ({percent_at_max:.1f}%)")

        # Fail if too many eigenvalues are saturated
        assert percent_at_max < 80, (
            f"{percent_at_max:.1f}% of eigenvalues are saturated at 1.0! "
            f"Raw eigenvalues not scaled before clamping. "
            f"Range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]"
        )

        # Should have good diversity
        assert np.std(non_zero_eigvals) > 0.1, (
            f"Eigenvalue std {np.std(non_zero_eigvals):.4f} is too low. "
            f"Should see diverse eigenvalues, not all the same."
        )

        # Mean should be in reasonable range
        mean_eigval = np.mean(non_zero_eigvals)
        assert 0.05 < mean_eigval < 0.9, (
            f"Mean eigenvalue {mean_eigval:.4f} is outside expected range [0.05, 0.9]. "
            f"If mean ≈ 1.0, eigenvalues are saturated."
        )

    def test_noise_magnitude_reasonable(self, tmp_path):
        """
        Verify CDC noise has reasonable magnitude for training.

        Checks that the noise returned by ``compute_sigma_t_x``:
        - Has a standard deviation within 10x of the input latents'
        - Yields a simulated MSE loss against the latents below 100
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
        )

        # Index grids; the sums are evaluated in float64 and cast to float32,
        # matching element-wise assignment of Python floats.
        c_idx, h_idx, w_idx = np.ogrid[:16, :4, :4]
        index_sum = c_idx + h_idx + w_idx

        for i in range(10):
            # latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
            latent = torch.from_numpy(index_sum / 20.0 + i * 0.1).to(torch.float32)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape, metadata=metadata)

        output_path = tmp_path / "test_gamma_b.safetensors"
        cdc_path = preprocessor.compute_all(save_path=output_path)

        # Load and compute noise
        gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")

        # Simulate training scenario with deterministic data:
        # latents[b, c, h, w] = (b + c + h + w) / 24.0
        batch_size = 3
        b_idx, bc_idx, bh_idx, bw_idx = np.ogrid[:batch_size, :16, :4, :4]
        latents = torch.from_numpy((b_idx + bc_idx + bh_idx + bw_idx) / 24.0).to(torch.float32)
        t = torch.tensor([0.5, 0.7, 0.9])  # Different timesteps
        image_keys = ['test_image_0', 'test_image_5', 'test_image_9']

        eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(image_keys)
        noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)

        # Check noise magnitude
        noise_std = noise.std().item()
        latent_std = latents.std().item()

        # Noise should be similar magnitude to input latents (within 10x)
        ratio = noise_std / latent_std
        assert 0.1 < ratio < 10.0, (
            f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
            f"ratio {ratio:.2f} is too extreme. Will cause training instability."
        )

        # Simulated MSE loss should be reasonable
        simulated_loss = torch.mean((noise - latents) ** 2).item()
        assert simulated_loss < 100.0, (
            f"Simulated MSE loss {simulated_loss:.2f} is too high. "
            f"Should be O(0.1-1.0) for stable training."
        )

        print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
        print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the
    # calling shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))