Kohya-ss-sd-scripts/tests/library/test_cdc_gradient_flow.py

"""
CDC Gradient Flow Verification Tests
This module provides testing of:
1. Mock dataset gradient preservation
2. Real dataset gradient flow
3. Various time steps and computation paths
4. Fallback and edge case scenarios
"""
import pytest
import torch
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation


class MockGammaBDataset:
    """Mock implementation of GammaBDataset for testing gradient flow."""

    def __init__(self, *args, **kwargs):
        """Simple initialization that doesn't require file loading."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Simplified implementation of compute_sigma_t_x for testing."""
        # Store original shape to restore later
        orig_shape = x.shape

        # Flatten x if it's 4D
        if x.dim() == 4:
            B, C, H, W = x.shape
            x = x.reshape(B, -1)  # (B, C*H*W)

        # Validate dimensions
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"

        # Early return for t=0, but only when t is not learnable; a learnable t
        # must stay on the interpolation path below so its gradient is preserved
        if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
            return x.reshape(orig_shape)

        # Compute Σ_t @ x
        # V^T x
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        # sqrt(λ) * V^T x
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        # V @ (sqrt(λ) * V^T x)
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)

        # Interpolate between original and noisy latent
        result = (1 - t) * x + t * gamma_sqrt_x

        # Restore original shape
        result = result.reshape(orig_shape)
        return result
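

# Hedged sanity check (a sketch, not part of the original suite): verifies that
# the einsum chain in compute_sigma_t_x agrees with an explicit low-rank
# construction M = Σ_k sqrt(λ_k) v_k v_kᵀ, so that Σ_t x = (1 - t) x + t · M x.
# Shapes follow the mock's conventions: eigenvectors (B, K, D), eigenvalues
# (B, K), x (B, D).
def test_mock_sigma_t_x_matches_explicit_matrix():
    torch.manual_seed(0)
    B, K, D = 2, 4, 16
    x = torch.randn(B, D)
    eigenvectors = torch.randn(B, K, D)
    eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
    eigenvalues = torch.rand(B, K).clamp(min=1e-4)
    t = torch.tensor(0.7)

    out = MockGammaBDataset().compute_sigma_t_x(eigenvectors, eigenvalues, x, t)

    # Explicit dense matrix M[b] = Σ_k sqrt(λ_k) v_k v_kᵀ, then M[b] @ x[b]
    M = torch.einsum('bkd,bk,bke->bde', eigenvectors, eigenvalues.sqrt(), eigenvectors)
    expected = (1 - t) * x + t * torch.einsum('bde,be->bd', M, x)
    assert torch.allclose(out, expected, atol=1e-5)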


class TestCDCGradientFlow:
    """Gradient flow testing for CDC noise transformations."""

    def setup_method(self):
        """Prepare consistent test environment"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_mock_gradient_flow_near_zero_time_step(self):
        """
        Verify gradient flow preservation for near-zero time steps
        using mock dataset with learnable time embeddings
        """
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Create a learnable time embedding with small initial value
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)

        # Generate mock latent and CDC components
        batch_size, latent_dim = 4, 64
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

        # Create mock eigenvectors and eigenvalues
        eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, 8, device=self.device)

        # Normalize eigenvectors and keep eigenvalues in a meaningful range
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

        # Use the mock dataset
        mock_dataset = MockGammaBDataset()

        # Compute noisy latent with gradient tracking
        noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)

        # Compute a dummy loss and backpropagate to check gradient flow
        loss = noisy_latent.sum()
        loss.backward()

        # Assertions to verify gradient flow
        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        # Check gradient magnitudes are non-zero
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()
        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"
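
    # Hedged addition (a sketch, not part of the original suite): uses
    # torch.autograd.gradcheck to compare the analytic gradients of the mock's
    # compute_sigma_t_x against finite differences. gradcheck requires float64
    # inputs; the eigen-components are held fixed and only (x, t) are checked.
    def test_mock_gradcheck_against_finite_differences(self):
        torch.manual_seed(0)
        B, K, D = 2, 3, 8
        eigenvectors = torch.randn(B, K, D, dtype=torch.float64)
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.rand(B, K, dtype=torch.float64).clamp(min=1e-2)
        x = torch.randn(B, D, dtype=torch.float64, requires_grad=True)
        t = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)

        mock = MockGammaBDataset()
        assert torch.autograd.gradcheck(
            lambda x_, t_: mock.compute_sigma_t_x(eigenvectors, eigenvalues, x_, t_),
            (x, t),
        )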
    def test_gradient_flow_with_multiple_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        # Test time steps
        time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]

        for time_val in time_steps:
            # Create a learnable time embedding
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)

            # Generate mock latent and CDC components
            batch_size, latent_dim = 4, 64
            latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

            # Create mock eigenvectors and eigenvalues
            eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
            eigenvalues = torch.rand(batch_size, 8, device=self.device)

            # Normalize eigenvectors and keep eigenvalues in a meaningful range
            eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
            eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

            # Use the mock dataset
            mock_dataset = MockGammaBDataset()

            # Compute noisy latent with gradient tracking
            noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)

            # Compute a dummy loss and backpropagate to check gradient flow
            loss = noisy_latent.sum()
            loss.backward()

            # Assertions to verify gradient flow
            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()
            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"

            # Reset gradients for the next iteration (t and latent are recreated
            # each loop, but clearing keeps the intent explicit)
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()
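
    # Hedged addition (a sketch, not part of the original suite): exercises the
    # mock's 4D flattening path, checking that a (B, C, H, W) latent round-trips
    # its shape and matches the equivalent pre-flattened 2D computation.
    def test_mock_sigma_t_x_preserves_4d_shape(self):
        torch.manual_seed(0)
        B, C, H, W = 2, 4, 3, 3
        D = C * H * W
        x = torch.randn(B, C, H, W, device=self.device)
        eigenvectors = torch.randn(B, 5, D, device=self.device)
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.rand(B, 5, device=self.device).clamp(min=1e-4)
        t = torch.tensor(0.3, device=self.device)

        mock = MockGammaBDataset()
        out_4d = mock.compute_sigma_t_x(eigenvectors, eigenvalues, x, t)
        out_2d = mock.compute_sigma_t_x(eigenvectors, eigenvalues, x.reshape(B, -1), t)

        assert out_4d.shape == x.shape
        assert torch.allclose(out_4d.reshape(B, -1), out_2d)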
    def test_gradient_flow_with_real_dataset(self, tmp_path):
        """
        Test gradient flow with real CDC dataset
        """
        # Create cache with uniform shapes
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        shape = (16, 32, 32)
        for i in range(10):
            latent = torch.randn(*shape, dtype=torch.float32)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)

        cache_path = tmp_path / "test_gradient.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")

        # Prepare test noise
        torch.manual_seed(42)
        noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
        image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']

        # Apply CDC transformation
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu",
        )

        # Verify gradient flow
        assert noise_out.requires_grad, "Output should require gradients"
        loss = noise_out.sum()
        loss.backward()
        assert noise.grad is not None, "Gradients should flow back to input noise"
        assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
        assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
        assert (noise.grad != 0).any(), "Gradients should not be all zeros"
    def test_gradient_flow_with_fallback(self, tmp_path):
        """
        Test gradient flow when using Gaussian fallback (shape mismatch).

        Ensures that cloned tensors maintain gradient flow correctly
        even when a shape mismatch triggers Gaussian noise.
        """
        # Create cache with one shape
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        preprocessed_shape = (16, 32, 32)
        latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
        metadata = {'image_key': 'test_image_0'}
        preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape, metadata=metadata)

        cache_path = tmp_path / "test_fallback_gradient.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")

        # Use a different shape at runtime (will trigger the fallback)
        runtime_shape = (16, 64, 64)
        noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0], dtype=torch.float32)
        image_keys = ['test_image_0']

        # Apply transformation (should fall back to Gaussian for this sample)
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu",
        )

        # Ensure gradients still flow through the fallback path
        assert noise_out.requires_grad, "Fallback output should require gradients"
        loss = noise_out.sum()
        loss.backward()
        assert noise.grad is not None, "Gradients should flow even in fallback case"
        assert not torch.isnan(noise.grad).any(), "Fallback gradients should not contain NaN"


def pytest_configure(config):
    """
    Configure custom markers for CDC gradient flow tests
    """
    config.addinivalue_line(
        "markers",
        "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching",
    )
    config.addinivalue_line(
        "markers",
        "mock_dataset: mark test using mock dataset for simplified gradient testing",
    )
    config.addinivalue_line(
        "markers",
        "real_dataset: mark test using real dataset for comprehensive gradient testing",
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])