From 0d822b2f74b5101ccf3fcb52384a420bd9d20638 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 9 Oct 2025 15:30:41 -0400
Subject: [PATCH] Refactor: Extract CDC noise transformation to separate function

- Create apply_cdc_noise_transformation() for better modularity
- Implement fast path for batch processing when all shapes match
- Implement slow path for per-sample processing on shape mismatch
- Clone noise tensors in fallback path for gradient consistency
---
Review notes (placed after the "---" separator, so they are not part of
the commit message):

  * This copy of the patch was rebuilt from a whitespace-collapsed paste.
    No "+"/"-" content was changed. Added/removed line counts match the
    hunk headers and the diffstat (1 + 70 + 8 insertions, 35 deletions).
    Leading indentation and blank context lines are inferred from Python
    syntax; run `git apply --check` against the target tree before use.
  * The fast path hands t_normalized to compute_sigma_t_x() as-is, while
    the slow path slices it per sample only when t_normalized.dim() > 0.
    Confirm compute_sigma_t_x() accepts both forms.
  * indices_list calls .item() once per element when the elements are
    tensors. If batch_indices is always a 1-D tensor, one
    batch_indices.tolist() call would produce the same list; confirm
    against the caller before changing.
  * A shape mismatch emits one logger.warning per affected sample on
    every call of apply_cdc_noise_transformation().

 .gitignore                  |   1 +
 library/flux_train_utils.py | 113 +++++++++++++++++++++++++-----------
 2 files changed, 79 insertions(+), 35 deletions(-)

diff --git a/.gitignore b/.gitignore
index cfdc0268..a3272cc4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ GEMINI.md
 .claude
 .gemini
 MagicMock
+benchmark_*.py
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index b40a1654..98c41d71 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -466,6 +466,76 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting
 
 
+def apply_cdc_noise_transformation(
+    noise: torch.Tensor,
+    timesteps: torch.Tensor,
+    num_timesteps: int,
+    gamma_b_dataset,
+    batch_indices,
+    device
+) -> torch.Tensor:
+    """
+    Apply CDC-FM geometry-aware noise transformation.
+
+    Args:
+        noise: (B, C, H, W) standard Gaussian noise
+        timesteps: (B,) timesteps for this batch
+        num_timesteps: Total number of timesteps in scheduler
+        gamma_b_dataset: GammaBDataset with cached CDC matrices
+        batch_indices: (B,) global dataset indices for this batch
+        device: Device to load CDC matrices to
+
+    Returns:
+        Transformed noise with geometry-aware covariance
+    """
+    # Normalize timesteps to [0, 1] for CDC-FM
+    t_normalized = timesteps / num_timesteps
+
+    B, C, H, W = noise.shape
+    current_shape = (C, H, W)
+
+    # Fast path: Check if all samples have matching shapes (common case)
+    # This avoids per-sample processing when bucketing is consistent
+    indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)]
+    cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list]
+
+    all_match = all(s == current_shape for s in cached_shapes)
+
+    if all_match:
+        # Batch processing: All shapes match, process entire batch at once
+        eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device)
+        noise_flat = noise.reshape(B, -1)
+        noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
+        return noise_cdc_flat.reshape(B, C, H, W)
+    else:
+        # Slow path: Some shapes mismatch, process individually
+        noise_transformed = []
+
+        for i in range(B):
+            idx = indices_list[i]
+            cached_shape = cached_shapes[i]
+
+            if cached_shape != current_shape:
+                # Shape mismatch - use standard Gaussian noise for this sample
+                logger.warning(
+                    f"CDC shape mismatch for sample {idx}: "
+                    f"cached {cached_shape} vs current {current_shape}. "
+                    f"Using Gaussian noise (no CDC)."
+                )
+                noise_transformed.append(noise[i].clone())
+            else:
+                # Shapes match - apply CDC transformation
+                eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
+
+                noise_flat = noise[i].reshape(1, -1)
+                t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
+
+                noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
+                noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
+
+        return torch.stack(noise_transformed, dim=0)
+
+
 def get_noisy_model_input_and_timesteps(
     args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
     gamma_b_dataset=None, batch_indices=None
@@ -522,41 +592,14 @@ def get_noisy_model_input_and_timesteps(
 
     # Apply CDC-FM geometry-aware noise transformation if enabled
     if gamma_b_dataset is not None and batch_indices is not None:
-        # Normalize timesteps to [0, 1] for CDC-FM
-        t_normalized = timesteps / num_timesteps
-
-        # Process each sample individually to handle potential dimension mismatches
-        # (can happen with multi-subset training where bucketing differs between preprocessing and training)
-        B, C, H, W = noise.shape
-        noise_transformed = []
-
-        for i in range(B):
-            idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
-
-            # Get cached shape for this sample
-            cached_shape = gamma_b_dataset.get_shape(idx)
-            current_shape = (C, H, W)
-
-            if cached_shape != current_shape:
-                # Shape mismatch - sample was bucketed differently between preprocessing and training
-                # Use standard Gaussian noise for this sample (no CDC)
-                logger.warning(
-                    f"CDC shape mismatch for sample {idx}: "
-                    f"cached {cached_shape} vs current {current_shape}. "
-                    f"Using Gaussian noise (no CDC)."
-                )
-                noise_transformed.append(noise[i])
-            else:
-                # Shapes match - apply CDC transformation
-                eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
-
-                noise_flat = noise[i].reshape(1, -1)
-                t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
-
-                noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
-                noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
-
-        noise = torch.stack(noise_transformed, dim=0)
+        noise = apply_cdc_noise_transformation(
+            noise=noise,
+            timesteps=timesteps,
+            num_timesteps=num_timesteps,
+            gamma_b_dataset=gamma_b_dataset,
+            batch_indices=batch_indices,
+            device=device
+        )
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)