Refactor: Extract CDC noise transformation to separate function

- Create apply_cdc_noise_transformation() for better modularity
- Implement fast path for batch processing when all shapes match
- Implement slow path for per-sample processing on shape mismatch
- Clone noise tensors in fallback path for gradient consistency
This commit is contained in:
rockerBOO
2025-10-09 15:30:41 -04:00
parent e03200bdba
commit 0d822b2f74
2 changed files with 79 additions and 35 deletions

1
.gitignore vendored
View File

@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
benchmark_*.py

View File

@@ -466,6 +466,76 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def apply_cdc_noise_transformation(
noise: torch.Tensor,
timesteps: torch.Tensor,
num_timesteps: int,
gamma_b_dataset,
batch_indices,
device
) -> torch.Tensor:
"""
Apply CDC-FM geometry-aware noise transformation.
Args:
noise: (B, C, H, W) standard Gaussian noise
timesteps: (B,) timesteps for this batch
num_timesteps: Total number of timesteps in scheduler
gamma_b_dataset: GammaBDataset with cached CDC matrices
batch_indices: (B,) global dataset indices for this batch
device: Device to load CDC matrices to
Returns:
Transformed noise with geometry-aware covariance
"""
# Normalize timesteps to [0, 1] for CDC-FM
t_normalized = timesteps / num_timesteps
B, C, H, W = noise.shape
current_shape = (C, H, W)
# Fast path: Check if all samples have matching shapes (common case)
# This avoids per-sample processing when bucketing is consistent
indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)]
cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list]
all_match = all(s == current_shape for s in cached_shapes)
if all_match:
# Batch processing: All shapes match, process entire batch at once
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device)
noise_flat = noise.reshape(B, -1)
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
return noise_cdc_flat.reshape(B, C, H, W)
else:
# Slow path: Some shapes mismatch, process individually
noise_transformed = []
for i in range(B):
idx = indices_list[i]
cached_shape = cached_shapes[i]
if cached_shape != current_shape:
# Shape mismatch - use standard Gaussian noise for this sample
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
noise_transformed.append(noise[i].clone())
else:
# Shapes match - apply CDC transformation
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
noise_flat = noise[i].reshape(1, -1)
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
return torch.stack(noise_transformed, dim=0)
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, batch_indices=None
@@ -522,41 +592,14 @@ def get_noisy_model_input_and_timesteps(
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and batch_indices is not None:
# Normalize timesteps to [0, 1] for CDC-FM
t_normalized = timesteps / num_timesteps
# Process each sample individually to handle potential dimension mismatches
# (can happen with multi-subset training where bucketing differs between preprocessing and training)
B, C, H, W = noise.shape
noise_transformed = []
for i in range(B):
idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
# Get cached shape for this sample
cached_shape = gamma_b_dataset.get_shape(idx)
current_shape = (C, H, W)
if cached_shape != current_shape:
# Shape mismatch - sample was bucketed differently between preprocessing and training
# Use standard Gaussian noise for this sample (no CDC)
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
noise_transformed.append(noise[i])
else:
# Shapes match - apply CDC transformation
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
noise_flat = noise[i].reshape(1, -1)
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
noise = torch.stack(noise_transformed, dim=0)
noise = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=num_timesteps,
gamma_b_dataset=gamma_b_dataset,
batch_indices=batch_indices,
device=device
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)