mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix: Enable gradient flow through CDC noise transformation
- Remove @torch.no_grad() decorator from compute_sigma_t_x() - Gradients now properly flow through CDC transformation during training - Add comprehensive gradient flow tests for fast/slow paths and fallback - All 25 CDC tests passing
This commit is contained in:
@@ -655,7 +655,6 @@ class GammaBDataset:
|
||||
"""Get the original shape for a sample (cached in memory)"""
|
||||
return self.shapes_cache[idx]
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_sigma_t_x(
|
||||
self,
|
||||
eigenvectors: torch.Tensor,
|
||||
@@ -674,6 +673,9 @@ class GammaBDataset:
|
||||
|
||||
Returns:
|
||||
result: Same shape as input x
|
||||
|
||||
Note:
|
||||
Gradients flow through this function for backprop during training.
|
||||
"""
|
||||
# Store original shape to restore later
|
||||
orig_shape = x.shape
|
||||
|
||||
199
tests/library/test_cdc_gradient_flow.py
Normal file
199
tests/library/test_cdc_gradient_flow.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Test gradient flow through CDC noise transformation.
|
||||
|
||||
Ensures that gradients propagate correctly through both fast and slow paths.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
from library.flux_train_utils import apply_cdc_noise_transformation
|
||||
|
||||
|
||||
class TestCDCGradientFlow:
|
||||
"""Test gradient flow through CDC transformations"""
|
||||
|
||||
@pytest.fixture
|
||||
def cdc_cache(self, tmp_path):
|
||||
"""Create a test CDC cache"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Create samples with same shape for fast path testing
|
||||
shape = (16, 32, 32)
|
||||
for i in range(20):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
|
||||
cache_path = tmp_path / "test_gradient.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
return cache_path
|
||||
|
||||
def test_gradient_flow_fast_path(self, cdc_cache):
|
||||
"""
|
||||
Test that gradients flow correctly through batch processing (fast path).
|
||||
|
||||
All samples have matching shapes, so CDC uses batch processing.
|
||||
"""
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
|
||||
|
||||
batch_size = 4
|
||||
shape = (16, 32, 32)
|
||||
|
||||
# Create input noise with requires_grad
|
||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
|
||||
# Apply CDC transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Ensure output requires grad
|
||||
assert noise_out.requires_grad, "Output should require gradients"
|
||||
|
||||
# Compute a simple loss and backprop
|
||||
loss = noise_out.sum()
|
||||
loss.backward()
|
||||
|
||||
# Verify gradients were computed for input
|
||||
assert noise.grad is not None, "Gradients should flow back to input noise"
|
||||
assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
|
||||
assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
|
||||
assert (noise.grad != 0).any(), "Gradients should not be all zeros"
|
||||
|
||||
def test_gradient_flow_slow_path_all_match(self, cdc_cache):
|
||||
"""
|
||||
Test gradient flow when slow path is taken but all shapes match.
|
||||
|
||||
This tests the per-sample loop with CDC transformation.
|
||||
"""
|
||||
dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
|
||||
|
||||
batch_size = 4
|
||||
shape = (16, 32, 32)
|
||||
|
||||
noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
|
||||
# Apply transformation
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Test gradient flow
|
||||
loss = noise_out.sum()
|
||||
loss.backward()
|
||||
|
||||
assert noise.grad is not None
|
||||
assert not torch.isnan(noise.grad).any()
|
||||
assert (noise.grad != 0).any()
|
||||
|
||||
def test_gradient_consistency_between_paths(self, tmp_path):
|
||||
"""
|
||||
Test that fast path and slow path produce similar gradients.
|
||||
|
||||
When all shapes match, both paths should give consistent results.
|
||||
"""
|
||||
# Create cache with uniform shapes
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
shape = (16, 32, 32)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
|
||||
cache_path = tmp_path / "test_consistency.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
|
||||
# Same input for both tests
|
||||
torch.manual_seed(42)
|
||||
noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
|
||||
|
||||
# Apply CDC (should use fast path)
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Compute gradients
|
||||
loss = noise_out.sum()
|
||||
loss.backward()
|
||||
|
||||
# Both paths should produce valid gradients
|
||||
assert noise.grad is not None
|
||||
assert not torch.isnan(noise.grad).any()
|
||||
|
||||
def test_fallback_gradient_flow(self, tmp_path):
|
||||
"""
|
||||
Test gradient flow when using Gaussian fallback (shape mismatch).
|
||||
|
||||
Ensures that cloned tensors maintain gradient flow correctly.
|
||||
"""
|
||||
# Create cache with one shape
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
preprocessed_shape = (16, 32, 32)
|
||||
latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape)
|
||||
|
||||
cache_path = tmp_path / "test_fallback.safetensors"
|
||||
preprocessor.compute_all(save_path=cache_path)
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
|
||||
# Use different shape at runtime (will trigger fallback)
|
||||
runtime_shape = (16, 64, 64)
|
||||
noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
|
||||
timesteps = torch.tensor([100.0], dtype=torch.float32)
|
||||
batch_indices = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
# Apply transformation (should fallback to Gaussian for this sample)
|
||||
# Note: This will log a warning but won't raise
|
||||
noise_out = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=1000,
|
||||
gamma_b_dataset=dataset,
|
||||
batch_indices=batch_indices,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Ensure gradients still flow through fallback path
|
||||
assert noise_out.requires_grad, "Fallback output should require gradients"
|
||||
|
||||
loss = noise_out.sum()
|
||||
loss.backward()
|
||||
|
||||
assert noise.grad is not None, "Gradients should flow even in fallback case"
|
||||
assert not torch.isnan(noise.grad).any()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
Reference in New Issue
Block a user