Add CDC-FM (Carré du Champ Flow Matching) support

Implements geometry-aware noise generation for FLUX training based on
arXiv:2510.05930v1.
This commit is contained in:
rockerBOO
2025-10-09 15:18:43 -04:00
parent 5e366acda4
commit f552f9a3bd
8 changed files with 1615 additions and 13 deletions

View File

@@ -2,10 +2,8 @@ import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
from accelerate import Accelerator, PartialState
@@ -183,7 +181,7 @@ def sample_image_inference(
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info("negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
@@ -469,8 +467,16 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, batch_indices=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get noisy model input and timesteps for training.
Args:
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
@@ -514,6 +520,44 @@ def get_noisy_model_input_and_timesteps(
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and batch_indices is not None:
# Normalize timesteps to [0, 1] for CDC-FM
t_normalized = timesteps / num_timesteps
# Process each sample individually to handle potential dimension mismatches
# (can happen with multi-subset training where bucketing differs between preprocessing and training)
B, C, H, W = noise.shape
noise_transformed = []
for i in range(B):
idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
# Get cached shape for this sample
cached_shape = gamma_b_dataset.get_shape(idx)
current_shape = (C, H, W)
if cached_shape != current_shape:
# Shape mismatch - sample was bucketed differently between preprocessing and training
# Use standard Gaussian noise for this sample (no CDC)
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
noise_transformed.append(noise[i])
else:
# Shapes match - apply CDC transformation
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
noise_flat = noise[i].reshape(1, -1)
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
noise = torch.stack(noise_transformed, dim=0)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma: