Fix: Enable gradient flow through CDC noise transformation

- Remove @torch.no_grad() decorator from compute_sigma_t_x()
- Gradients now properly flow through CDC transformation during training
- Add comprehensive gradient flow tests for fast/slow paths and fallback
- All 25 CDC tests passing
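(For context, a minimal standalone sketch of the stock PyTorch behavior behind this fix; this is not project code. Any tensor produced inside a @torch.no_grad() function is detached from the autograd graph, so a decorated transform silently blocks backprop to its inputs.)

import torch

@torch.no_grad()
def scale_no_grad(x: torch.Tensor) -> torch.Tensor:
    # Ops here are not recorded; the result is detached from the graph
    return x * 2.0

def scale_with_grad(x: torch.Tensor) -> torch.Tensor:
    # Without the decorator, the multiply is recorded for backprop
    return x * 2.0

x = torch.randn(4, requires_grad=True)
print(scale_no_grad(x).requires_grad)    # False: calling .backward() here would raise
print(scale_with_grad(x).requires_grad)  # True: gradients can flow back to x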
rockerBOO
2025-10-09 15:35:00 -04:00
parent 0d822b2f74
commit 88af20881d
2 changed files with 202 additions and 1 deletion


@@ -655,7 +655,6 @@ class GammaBDataset:
         """Get the original shape for a sample (cached in memory)"""
         return self.shapes_cache[idx]
 
-    @torch.no_grad()
     def compute_sigma_t_x(
         self,
         eigenvectors: torch.Tensor,
@@ -674,6 +673,9 @@ class GammaBDataset:
 
         Returns:
             result: Same shape as input x
+
+        Note:
+            Gradients flow through this function for backprop during training.
         """
         # Store original shape to restore later
         orig_shape = x.shape

@@ -0,0 +1,199 @@
"""
Test gradient flow through CDC noise transformation.
Ensures that gradients propagate correctly through both fast and slow paths.
"""
import pytest
import torch
import tempfile
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class TestCDCGradientFlow:
"""Test gradient flow through CDC transformations"""
@pytest.fixture
def cdc_cache(self, tmp_path):
"""Create a test CDC cache"""
preprocessor = CDCPreprocessor(
k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Create samples with same shape for fast path testing
shape = (16, 32, 32)
for i in range(20):
latent = torch.randn(*shape, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
cache_path = tmp_path / "test_gradient.safetensors"
preprocessor.compute_all(save_path=cache_path)
return cache_path
    def test_gradient_flow_fast_path(self, cdc_cache):
        """
        Test that gradients flow correctly through batch processing (fast path).
        All samples have matching shapes, so CDC uses batch processing.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        batch_size = 4
        shape = (16, 32, 32)
        # Create input noise with requires_grad
        noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
        batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
        # Apply CDC transformation
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            batch_indices=batch_indices,
            device="cpu"
        )
        # Ensure output requires grad
        assert noise_out.requires_grad, "Output should require gradients"
        # Compute a simple loss and backprop
        loss = noise_out.sum()
        loss.backward()
        # Verify gradients were computed for input
        assert noise.grad is not None, "Gradients should flow back to input noise"
        assert not torch.isnan(noise.grad).any(), "Gradients should not contain NaN"
        assert not torch.isinf(noise.grad).any(), "Gradients should not contain inf"
        assert (noise.grad != 0).any(), "Gradients should not be all zeros"

    def test_gradient_flow_slow_path_all_match(self, cdc_cache):
        """
        Test gradient flow when the slow path is taken but all shapes match.
        This tests the per-sample loop with CDC transformation.
        """
        dataset = GammaBDataset(gamma_b_path=cdc_cache, device="cpu")
        batch_size = 4
        shape = (16, 32, 32)
        noise = torch.randn(batch_size, *shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
        batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
        # Apply transformation
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            batch_indices=batch_indices,
            device="cpu"
        )
        # Test gradient flow
        loss = noise_out.sum()
        loss.backward()
        assert noise.grad is not None
        assert not torch.isnan(noise.grad).any()
        assert (noise.grad != 0).any()
    def test_gradient_consistency_between_paths(self, tmp_path):
        """
        Test that fast path and slow path produce similar gradients.
        When all shapes match, both paths should give consistent results.
        """
        # Create cache with uniform shapes
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        shape = (16, 32, 32)
        for i in range(10):
            latent = torch.randn(*shape, dtype=torch.float32)
            preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
        cache_path = tmp_path / "test_consistency.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
        # Same input for both tests
        torch.manual_seed(42)
        noise = torch.randn(4, *shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0, 300.0, 400.0], dtype=torch.float32)
        batch_indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
        # Apply CDC (should use fast path)
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            batch_indices=batch_indices,
            device="cpu"
        )
        # Compute gradients
        loss = noise_out.sum()
        loss.backward()
        # Both paths should produce valid gradients
        assert noise.grad is not None
        assert not torch.isnan(noise.grad).any()
    def test_fallback_gradient_flow(self, tmp_path):
        """
        Test gradient flow when using the Gaussian fallback (shape mismatch).
        Ensures that cloned tensors maintain gradient flow correctly.
        """
        # Create cache with one shape
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
        )
        preprocessed_shape = (16, 32, 32)
        latent = torch.randn(*preprocessed_shape, dtype=torch.float32)
        preprocessor.add_latent(latent=latent, global_idx=0, shape=preprocessed_shape)
        cache_path = tmp_path / "test_fallback.safetensors"
        preprocessor.compute_all(save_path=cache_path)
        dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
        # Use a different shape at runtime (will trigger fallback)
        runtime_shape = (16, 64, 64)
        noise = torch.randn(1, *runtime_shape, dtype=torch.float32, requires_grad=True)
        timesteps = torch.tensor([100.0], dtype=torch.float32)
        batch_indices = torch.tensor([0], dtype=torch.long)
        # Apply transformation (should fall back to Gaussian for this sample)
        # Note: This will log a warning but won't raise
        noise_out = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            batch_indices=batch_indices,
            device="cpu"
        )
        # Ensure gradients still flow through the fallback path
        assert noise_out.requires_grad, "Fallback output should require gradients"
        loss = noise_out.sum()
        loss.backward()
        assert noise.grad is not None, "Gradients should flow even in fallback case"
        assert not torch.isnan(noise.grad).any()


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])