mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Refactor: Extract CDC noise transformation to separate function
- Create apply_cdc_noise_transformation() for better modularity - Implement fast path for batch processing when all shapes match - Implement slow path for per-sample processing on shape mismatch - Clone noise tensors in fallback path for gradient consistency
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,3 +11,4 @@ GEMINI.md
|
||||
.claude
|
||||
.gemini
|
||||
MagicMock
|
||||
benchmark_*.py
|
||||
|
||||
@@ -466,6 +466,76 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def apply_cdc_noise_transformation(
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
num_timesteps: int,
|
||||
gamma_b_dataset,
|
||||
batch_indices,
|
||||
device
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply CDC-FM geometry-aware noise transformation.
|
||||
|
||||
Args:
|
||||
noise: (B, C, H, W) standard Gaussian noise
|
||||
timesteps: (B,) timesteps for this batch
|
||||
num_timesteps: Total number of timesteps in scheduler
|
||||
gamma_b_dataset: GammaBDataset with cached CDC matrices
|
||||
batch_indices: (B,) global dataset indices for this batch
|
||||
device: Device to load CDC matrices to
|
||||
|
||||
Returns:
|
||||
Transformed noise with geometry-aware covariance
|
||||
"""
|
||||
# Normalize timesteps to [0, 1] for CDC-FM
|
||||
t_normalized = timesteps / num_timesteps
|
||||
|
||||
B, C, H, W = noise.shape
|
||||
current_shape = (C, H, W)
|
||||
|
||||
# Fast path: Check if all samples have matching shapes (common case)
|
||||
# This avoids per-sample processing when bucketing is consistent
|
||||
indices_list = [batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] for i in range(B)]
|
||||
cached_shapes = [gamma_b_dataset.get_shape(idx) for idx in indices_list]
|
||||
|
||||
all_match = all(s == current_shape for s in cached_shapes)
|
||||
|
||||
if all_match:
|
||||
# Batch processing: All shapes match, process entire batch at once
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(indices_list, device=device)
|
||||
noise_flat = noise.reshape(B, -1)
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
||||
return noise_cdc_flat.reshape(B, C, H, W)
|
||||
else:
|
||||
# Slow path: Some shapes mismatch, process individually
|
||||
noise_transformed = []
|
||||
|
||||
for i in range(B):
|
||||
idx = indices_list[i]
|
||||
cached_shape = cached_shapes[i]
|
||||
|
||||
if cached_shape != current_shape:
|
||||
# Shape mismatch - use standard Gaussian noise for this sample
|
||||
logger.warning(
|
||||
f"CDC shape mismatch for sample {idx}: "
|
||||
f"cached {cached_shape} vs current {current_shape}. "
|
||||
f"Using Gaussian noise (no CDC)."
|
||||
)
|
||||
noise_transformed.append(noise[i].clone())
|
||||
else:
|
||||
# Shapes match - apply CDC transformation
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
|
||||
|
||||
noise_flat = noise[i].reshape(1, -1)
|
||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
|
||||
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
|
||||
|
||||
return torch.stack(noise_transformed, dim=0)
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||
gamma_b_dataset=None, batch_indices=None
|
||||
@@ -522,41 +592,14 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||
if gamma_b_dataset is not None and batch_indices is not None:
|
||||
# Normalize timesteps to [0, 1] for CDC-FM
|
||||
t_normalized = timesteps / num_timesteps
|
||||
|
||||
# Process each sample individually to handle potential dimension mismatches
|
||||
# (can happen with multi-subset training where bucketing differs between preprocessing and training)
|
||||
B, C, H, W = noise.shape
|
||||
noise_transformed = []
|
||||
|
||||
for i in range(B):
|
||||
idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
|
||||
|
||||
# Get cached shape for this sample
|
||||
cached_shape = gamma_b_dataset.get_shape(idx)
|
||||
current_shape = (C, H, W)
|
||||
|
||||
if cached_shape != current_shape:
|
||||
# Shape mismatch - sample was bucketed differently between preprocessing and training
|
||||
# Use standard Gaussian noise for this sample (no CDC)
|
||||
logger.warning(
|
||||
f"CDC shape mismatch for sample {idx}: "
|
||||
f"cached {cached_shape} vs current {current_shape}. "
|
||||
f"Using Gaussian noise (no CDC)."
|
||||
)
|
||||
noise_transformed.append(noise[i])
|
||||
else:
|
||||
# Shapes match - apply CDC transformation
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
|
||||
|
||||
noise_flat = noise[i].reshape(1, -1)
|
||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
|
||||
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
|
||||
|
||||
noise = torch.stack(noise_transformed, dim=0)
|
||||
noise = apply_cdc_noise_transformation(
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
num_timesteps=num_timesteps,
|
||||
gamma_b_dataset=gamma_b_dataset,
|
||||
batch_indices=batch_indices,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
|
||||
Reference in New Issue
Block a user