mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
- Merged redundant test files - Removed 'comprehensive' from file and docstring names - Improved test organization and clarity - Ensured all tests continue to pass - Simplified test documentation
297 lines
11 KiB
Python
297 lines
11 KiB
Python
"""
|
|
CDC Gradient Flow Verification Tests
|
|
|
|
This module provides testing of:
|
|
1. Mock dataset gradient preservation
|
|
2. Real dataset gradient flow
|
|
3. Various time steps and computation paths
|
|
4. Fallback and edge case scenarios
|
|
"""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
|
from library.flux_train_utils import apply_cdc_noise_transformation
|
|
|
|
|
|
class MockGammaBDataset:
    """
    Lightweight stand-in for GammaBDataset used in gradient-flow tests.

    Unlike the real dataset, construction is trivial (no cache files are
    loaded), so tests can focus purely on the differentiability of the
    Sigma_t @ x computation.
    """

    def __init__(self, *args, **kwargs):
        """Accept (and ignore) any arguments; select an available device."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """
        Simplified, differentiable implementation of compute_sigma_t_x.

        Computes (1 - t) * x + t * (V diag(sqrt(lambda)) V^T x), where V is
        `eigenvectors` (B, k, D) and lambda is `eigenvalues` (B, k). `x` may
        be (B, D) or (B, C, H, W) and is returned in its original shape.
        """
        input_shape = x.shape

        # Work on a flattened (B, D) view when given 4D latents.
        flat = torch.flatten(x, start_dim=1) if x.dim() == 4 else x

        # Sanity-check the eigensystem against the latent batch/dimension.
        assert eigenvectors.shape[0] == flat.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == flat.shape[1], "Dimension mismatch"

        # For a constant t == 0 the transform is the identity; returning the
        # input directly preserves x's gradient path.
        zero_like_t = torch.zeros_like(t)
        if torch.allclose(t, zero_like_t, atol=1e-8) and not t.requires_grad:
            return flat.reshape(input_shape)

        # Project onto the eigenbasis: V^T x.
        projected = torch.einsum('bkd,bd->bk', eigenvectors, flat)

        # Scale each coefficient by sqrt(lambda); clamp guards the sqrt.
        scaled = torch.sqrt(eigenvalues.clamp(min=1e-10)) * projected

        # Map back to latent space: V @ (sqrt(lambda) * V^T x).
        covariance_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, scaled)

        # Linear interpolation between the clean and transformed latent.
        blended = (1 - t) * flat + t * covariance_sqrt_x

        return blended.reshape(input_shape)
|
|
|
|
|
|
class TestCDCGradientFlow:
    """
    Gradient flow testing for CDC noise transformations.

    Covers mock-dataset gradient preservation (near-zero and varied time
    steps) plus real-dataset and Gaussian-fallback paths of
    apply_cdc_noise_transformation.
    """

    def setup_method(self):
        """Prepare consistent test environment."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _make_mock_inputs(self, batch_size=4, latent_dim=64, k=8):
        """Build a learnable latent plus normalized mock eigenvectors/eigenvalues.

        Returns (latent, eigenvectors, eigenvalues) on self.device; the
        latent requires grad so tests can verify backpropagation through
        the CDC transform.
        """
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
        eigenvectors = torch.randn(batch_size, k, latent_dim, device=self.device)
        # Normalize eigenvector rows and bound eigenvalues so the mock
        # eigensystem is numerically meaningful.
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(torch.rand(batch_size, k, device=self.device), min=1e-4, max=1.0)
        return latent, eigenvectors, eigenvalues

    def _build_gamma_b_dataset(self, tmp_path, filename, shape, count):
        """Preprocess `count` random latents of `shape` and load the cache.

        Each latent gets image_key 'test_image_<i>' so tests can address
        individual samples.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        for i in range(count):
            latent = torch.randn(*shape, dtype=torch.float32)
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(latent=latent, global_idx=i, shape=shape, metadata=metadata)
        cache_path = tmp_path / filename
        preprocessor.compute_all(save_path=cache_path)
        return GammaBDataset(gamma_b_path=cache_path, device="cpu")

    def test_mock_gradient_flow_near_zero_time_step(self):
        """
        Verify gradient flow preservation for near-zero time steps
        using mock dataset with learnable time embeddings.
        """
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Learnable time embedding with a small (but nonzero) initial value.
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
        latent, eigenvectors, eigenvalues = self._make_mock_inputs()

        # Compute noisy latent with gradient tracking.
        noisy_latent = MockGammaBDataset().compute_sigma_t_x(
            eigenvectors, eigenvalues, latent, t
        )

        # Dummy loss to drive backpropagation.
        loss = noisy_latent.sum()
        loss.backward()

        # Verify gradients reached both learnable inputs.
        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        # Check gradient magnitudes are non-zero.
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()

        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"

    def test_gradient_flow_with_multiple_time_steps(self):
        """
        Verify gradient flow across different time step values.
        """
        time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]

        for time_val in time_steps:
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
            latent, eigenvectors, eigenvalues = self._make_mock_inputs()

            noisy_latent = MockGammaBDataset().compute_sigma_t_x(
                eigenvectors, eigenvalues, latent, t
            )

            loss = noisy_latent.sum()
            loss.backward()

            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()

            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"
            # No gradient reset needed: t and latent are freshly created
            # each iteration, so their .grad starts at None every pass.

    def test_gradient_flow_with_real_dataset(self, tmp_path):
        """
        Test gradient flow with real CDC dataset.
        """
        # Build a cache where every sample shares one latent shape.
        shape = (16, 32, 32)
        dataset = self._build_gamma_b_dataset(tmp_path, "test_gradient.safetensors", shape, count=10)

        # Prepare test noise with gradient tracking.
        torch.manual_seed(42)
        noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
        image_keys = ['test_image_0', 'test_image_1', 'test_image_2', 'test_image_3']

        # Apply CDC transformation.
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu"
        )

        # Verify gradient flow through the transform.
        assert noise_out.requires_grad, "Output should require gradients"

        loss = noise_out.sum()
        loss.backward()

        assert noise.grad is not None, "Gradients should flow back to input noise"
        assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
        assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
        assert (noise.grad != 0).any(), "Gradients should not be all zeros"

    def test_gradient_flow_with_fallback(self, tmp_path):
        """
        Test gradient flow when using Gaussian fallback (shape mismatch).

        Ensures that cloned tensors maintain gradient flow correctly
        even when shape mismatch triggers Gaussian noise.
        """
        # Cache is built with one shape only.
        preprocessed_shape = (16, 32, 32)
        dataset = self._build_gamma_b_dataset(
            tmp_path, "test_fallback_gradient.safetensors", preprocessed_shape, count=1
        )

        # A different runtime shape should trigger the Gaussian fallback path.
        runtime_shape = (16, 64, 64)
        noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0], dtype=torch.float32)
        image_keys = ['test_image_0']

        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            image_keys=image_keys,
            device="cpu"
        )

        # Ensure gradients still flow through the fallback path.
        assert noise_out.requires_grad, "Fallback output should require gradients"

        loss = noise_out.sum()
        loss.backward()

        assert noise.grad is not None, "Gradients should flow even in fallback case"
        assert not torch.isnan(noise.grad).any(), "Fallback gradients should not contain NaN"
|
|
|
def pytest_configure(config):
    """
    Register the custom markers used by the CDC gradient-flow test suite.
    """
    marker_lines = (
        "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching",
        "mock_dataset: mark test using mock dataset for simplified gradient testing",
        "real_dataset: mark test using real dataset for comprehensive gradient testing",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v", "-s"]) |