mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Add CDC-FM (Carré du Champ Flow Matching) support
Implements geometry-aware noise generation for FLUX training based on arXiv:2510.05930v1.
This commit is contained in:
@@ -2,10 +2,8 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
import toml
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator, PartialState
|
||||
@@ -183,7 +181,7 @@ def sample_image_inference(
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
logger.info("negative prompt is ignored because scale is 1.0")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
@@ -469,8 +467,16 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||
gamma_b_dataset=None, batch_indices=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps for training.
|
||||
|
||||
Args:
|
||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
@@ -514,6 +520,44 @@ def get_noisy_model_input_and_timesteps(
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||
if gamma_b_dataset is not None and batch_indices is not None:
|
||||
# Normalize timesteps to [0, 1] for CDC-FM
|
||||
t_normalized = timesteps / num_timesteps
|
||||
|
||||
# Process each sample individually to handle potential dimension mismatches
|
||||
# (can happen with multi-subset training where bucketing differs between preprocessing and training)
|
||||
B, C, H, W = noise.shape
|
||||
noise_transformed = []
|
||||
|
||||
for i in range(B):
|
||||
idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
|
||||
|
||||
# Get cached shape for this sample
|
||||
cached_shape = gamma_b_dataset.get_shape(idx)
|
||||
current_shape = (C, H, W)
|
||||
|
||||
if cached_shape != current_shape:
|
||||
# Shape mismatch - sample was bucketed differently between preprocessing and training
|
||||
# Use standard Gaussian noise for this sample (no CDC)
|
||||
logger.warning(
|
||||
f"CDC shape mismatch for sample {idx}: "
|
||||
f"cached {cached_shape} vs current {current_shape}. "
|
||||
f"Using Gaussian noise (no CDC)."
|
||||
)
|
||||
noise_transformed.append(noise[i])
|
||||
else:
|
||||
# Shapes match - apply CDC transformation
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
|
||||
|
||||
noise_flat = noise[i].reshape(1, -1)
|
||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
|
||||
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
|
||||
|
||||
noise = torch.stack(noise_transformed, dim=0)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
|
||||
Reference in New Issue
Block a user