Add CDC-FM (Carré du Champ Flow Matching) support

Implements geometry-aware noise generation for FLUX training based on
arXiv:2510.05930v1.
This commit is contained in:
rockerBOO
2025-10-09 15:18:43 -04:00
parent 5e366acda4
commit f552f9a3bd
8 changed files with 1615 additions and 13 deletions

View File

@@ -1,7 +1,5 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union
import torch
@@ -36,6 +34,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
self.model_type: Optional[str] = None
self.gamma_b_dataset = None # CDC-FM Γ_b dataset
def assert_extra_args(
self,
@@ -327,9 +326,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
# Get CDC parameters if enabled
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "indices" in batch) else None
batch_indices = batch.get("indices") if gamma_b_dataset is not None else None
# Get noisy model input and timesteps
# If CDC is enabled, this will transform the noise with geometry-aware covariance
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
gamma_b_dataset=gamma_b_dataset, batch_indices=batch_indices
)
# pack latents and get img_ids
@@ -494,7 +499,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
logger.info("T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
@@ -533,6 +538,49 @@ def setup_parser() -> argparse.ArgumentParser:
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
# CDC-FM arguments
parser.add_argument(
"--use_cdc_fm",
action="store_true",
help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
" / CDC-FMCarré du Champ Flow Matchingを有効にして幾何学的イズを使用",
)
parser.add_argument(
"--cdc_k_neighbors",
type=int,
default=256,
help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
" / CDC-FMのk-NNグラフの近傍数デフォルト: 256",
)
parser.add_argument(
"--cdc_k_bandwidth",
type=int,
default=8,
help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
" / CDC-FMの帯域幅推定の近傍数デフォルト: 8",
)
parser.add_argument(
"--cdc_d_cdc",
type=int,
default=8,
help="Dimension of CDC subspace (default: 8)"
" / CDCサブ空間の次元デフォルト: 8",
)
parser.add_argument(
"--cdc_gamma",
type=float,
default=1.0,
help="CDC strength parameter (default: 1.0)"
" / CDC強度パラメータデフォルト: 1.0",
)
parser.add_argument(
"--force_recache_cdc",
action="store_true",
help="Force recompute CDC cache even if valid cache exists"
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
)
return parser

712
library/cdc_fm.py Normal file
View File

@@ -0,0 +1,712 @@
import logging
import torch
import numpy as np
import faiss # type: ignore
from pathlib import Path
from tqdm import tqdm
from safetensors.torch import save_file
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class LatentSample:
"""
Container for a single latent with metadata
"""
latent: np.ndarray # (d,) flattened latent vector
global_idx: int # Global index in dataset
shape: Tuple[int, ...] # Original shape before flattening (C, H, W)
metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.)
class CarreDuChampComputer:
"""
Core CDC-FM computation - agnostic to data source
Just handles the math for a batch of same-size latents
"""
def __init__(
self,
k_neighbors: int = 256,
k_bandwidth: int = 8,
d_cdc: int = 8,
gamma: float = 1.0,
device: str = 'cuda'
):
self.k = k_neighbors
self.k_bw = k_bandwidth
self.d_cdc = d_cdc
self.gamma = gamma
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Build k-NN graph using FAISS
Args:
latents_np: (N, d) numpy array of same-dimensional latents
Returns:
distances: (N, k_actual+1) distances (k_actual may be less than k if N is small)
indices: (N, k_actual+1) neighbor indices
"""
N, d = latents_np.shape
# Clamp k to available neighbors (can't have more neighbors than samples)
k_actual = min(self.k, N - 1)
# Ensure float32
if latents_np.dtype != np.float32:
latents_np = latents_np.astype(np.float32)
# Build FAISS index
index = faiss.IndexFlatL2(d)
if torch.cuda.is_available():
res = faiss.StandardGpuResources()
index = faiss.index_cpu_to_gpu(res, 0, index)
index.add(latents_np) # type: ignore
distances, indices = index.search(latents_np, k_actual + 1) # type: ignore
return distances, indices
@torch.no_grad()
def compute_gamma_b_single(
self,
point_idx: int,
latents_np: np.ndarray,
distances: np.ndarray,
indices: np.ndarray,
epsilon: np.ndarray
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute Γ_b for a single point
Args:
point_idx: Index of point to process
latents_np: (N, d) all latents in this batch
distances: (N, k+1) precomputed distances
indices: (N, k+1) precomputed neighbor indices
epsilon: (N,) bandwidth per point
Returns:
eigenvectors: (d_cdc, d) as half precision tensor
eigenvalues: (d_cdc,) as half precision tensor
"""
d = latents_np.shape[1]
# Get neighbors (exclude self)
neighbor_idx = indices[point_idx, 1:] # (k,)
neighbor_points = latents_np[neighbor_idx] # (k, d)
# Clamp distances to prevent overflow (max realistic L2 distance)
MAX_DISTANCE = 1e10
neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
neighbor_dists_sq = neighbor_dists ** 2 # (k,)
# Compute Gaussian kernel weights with numerical guards
eps_i = max(epsilon[point_idx], 1e-10) # Prevent division by zero
eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)
# Compute denominator with guard against overflow
denom = eps_i * eps_neighbors
denom = np.maximum(denom, 1e-20) # Additional guard
# Compute weights with safe exponential
exp_arg = -neighbor_dists_sq / denom
exp_arg = np.clip(exp_arg, -50, 0) # Prevent exp overflow/underflow
weights = np.exp(exp_arg)
# Normalize weights, handle edge case of all zeros
weight_sum = weights.sum()
if weight_sum < 1e-20 or not np.isfinite(weight_sum):
# Fallback to uniform weights
weights = np.ones_like(weights) / len(weights)
else:
weights = weights / weight_sum
# Compute local mean
m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
# Center and weight for SVD
centered = neighbor_points - m_star
weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d)
# Validate input is finite before SVD
if not np.all(np.isfinite(weighted_centered)):
logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
# Fallback: use uniform weights and simple centering
weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
m_star = np.mean(neighbor_points, axis=0)
centered = neighbor_points - m_star
weighted_centered = np.sqrt(weights_uniform)[:, None] * centered
# Move to GPU for SVD (100x speedup!)
weighted_centered_torch = torch.from_numpy(weighted_centered).to(
self.device, dtype=torch.float32
)
try:
U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
except RuntimeError as e:
logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
try:
U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False)
U = torch.from_numpy(U).to(self.device)
S = torch.from_numpy(S).to(self.device)
Vh = torch.from_numpy(Vh).to(self.device)
except np.linalg.LinAlgError as e2:
logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
# Return zero eigenvalues/vectors as fallback
return (
torch.zeros(self.d_cdc, d, dtype=torch.float16),
torch.zeros(self.d_cdc, dtype=torch.float16)
)
# Eigenvalues of Γ_b
eigenvalues_full = S ** 2
# Keep top d_cdc
if len(eigenvalues_full) >= self.d_cdc:
top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
top_eigenvectors = Vh[top_idx, :] # (d_cdc, d)
else:
# Pad if k < d_cdc
top_eigenvalues = eigenvalues_full
top_eigenvectors = Vh
if len(eigenvalues_full) < self.d_cdc:
pad_size = self.d_cdc - len(eigenvalues_full)
top_eigenvalues = torch.cat([
top_eigenvalues,
torch.zeros(pad_size, device=self.device)
])
top_eigenvectors = torch.cat([
top_eigenvectors,
torch.zeros(pad_size, d, device=self.device)
])
# Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
# Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
# Then apply gamma: γc_i Γ̂(x^(i))
#
# Our implementation:
# 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor
# 2. Apply gamma hyperparameter - aligns with paper's γ global scaling
# 3. Clamp for numerical stability
#
# Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents)
# Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound
# Step 1: Normalize by the maximum eigenvalue to get relative scales
# This is the paper's 1/λ_1^i normalization factor
max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0
if max_eigenval > 1e-10:
# Scale so max eigenvalue = 1.0, preserving relative ratios
top_eigenvalues = top_eigenvalues / max_eigenval
# Step 2: Apply gamma and clamp to safe range
# Gamma is the paper's tuneable hyperparameter (defaults to 1.0)
# Clamping ensures numerical stability and prevents extreme values
top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0)
# Convert to fp16 for storage - now safe since eigenvalues are ~0.01-1.0
# fp16 range: 6e-5 to 65,504, our values are well within this
eigenvectors_fp16 = top_eigenvectors.cpu().half()
eigenvalues_fp16 = top_eigenvalues.cpu().half()
# Cleanup
del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues
if torch.cuda.is_available():
torch.cuda.empty_cache()
return eigenvectors_fp16, eigenvalues_fp16
def compute_for_batch(
self,
latents_np: np.ndarray,
global_indices: List[int]
) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
"""
Compute Γ_b for all points in a batch of same-size latents
Args:
latents_np: (N, d) numpy array
global_indices: List of global dataset indices for each latent
Returns:
Dict mapping global_idx -> (eigenvectors, eigenvalues)
"""
N, d = latents_np.shape
# Validate inputs
if len(global_indices) != N:
raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")
print(f"Computing CDC for batch: {N} samples, dim={d}")
# Handle small sample cases - require minimum samples for meaningful k-NN
MIN_SAMPLES_FOR_CDC = 5 # Need at least 5 samples for reasonable geometry estimation
if N < MIN_SAMPLES_FOR_CDC:
print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
results = {}
for local_idx in range(N):
global_idx = global_indices[local_idx]
# Return zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x)
eigvecs = np.zeros((self.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.d_cdc, dtype=np.float16)
results[global_idx] = (eigvecs, eigvals)
return results
# Step 1: Build k-NN graph
print(" Building k-NN graph...")
distances, indices = self.compute_knn_graph(latents_np)
# Step 2: Compute bandwidth
# Use min to handle case where k_bw >= actual neighbors returned
k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
epsilon = distances[:, k_bw_actual]
# Step 3: Compute Γ_b for each point
results = {}
print(" Computing Γ_b for each point...")
for local_idx in tqdm(range(N), desc=" Processing", leave=False):
global_idx = global_indices[local_idx]
eigvecs, eigvals = self.compute_gamma_b_single(
local_idx, latents_np, distances, indices, epsilon
)
results[global_idx] = (eigvecs, eigvals)
return results
class LatentBatcher:
"""
Collects variable-size latents and batches them by size
"""
def __init__(self, size_tolerance: float = 0.0):
"""
Args:
size_tolerance: If > 0, group latents within tolerance % of size
If 0, only exact size matches are batched
"""
self.size_tolerance = size_tolerance
self.samples: List[LatentSample] = []
def add_sample(self, sample: LatentSample):
"""Add a single latent sample"""
self.samples.append(sample)
def add_latent(
self,
latent: Union[np.ndarray, torch.Tensor],
global_idx: int,
shape: Optional[Tuple[int, ...]] = None,
metadata: Optional[Dict] = None
):
"""
Add a latent vector with automatic shape tracking
Args:
latent: Latent vector (any shape, will be flattened)
global_idx: Global index in dataset
shape: Original shape (if None, uses latent.shape)
metadata: Optional metadata dict
"""
# Convert to numpy and flatten
if isinstance(latent, torch.Tensor):
latent_np = latent.cpu().numpy()
else:
latent_np = latent
original_shape = shape if shape is not None else latent_np.shape
latent_flat = latent_np.flatten()
sample = LatentSample(
latent=latent_flat,
global_idx=global_idx,
shape=original_shape,
metadata=metadata
)
self.add_sample(sample)
def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
"""
Group samples by exact shape to avoid resizing distortion.
Each bucket contains only samples with identical latent dimensions.
Buckets with fewer than k_neighbors samples will be skipped during CDC
computation and fall back to standard Gaussian noise.
Returns:
Dict mapping exact_shape -> list of samples with that shape
"""
batches = {}
for sample in self.samples:
shape_key = sample.shape
# Group by exact shape only - no aspect ratio grouping or resizing
if shape_key not in batches:
batches[shape_key] = []
batches[shape_key].append(sample)
return batches
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
"""
Get aspect ratio category for grouping.
Groups images by aspect ratio bins to ensure sufficient samples.
For shape (C, H, W), computes aspect ratio H/W and bins it.
"""
if len(shape) < 3:
return "unknown"
# Extract spatial dimensions (H, W)
h, w = shape[-2], shape[-1]
aspect_ratio = h / w
# Define aspect ratio bins (±15% tolerance)
# Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16)
bins = [
(0.5, 0.65, "9:16"), # Portrait tall
(0.65, 0.85, "3:4"), # Portrait
(0.85, 1.15, "1:1"), # Square
(1.15, 1.50, "4:3"), # Landscape
(1.50, 2.0, "16:9"), # Landscape wide
(2.0, 3.0, "21:9"), # Ultra wide
]
for min_ratio, max_ratio, label in bins:
if min_ratio <= aspect_ratio < max_ratio:
return label
# Fallback for extreme ratios
if aspect_ratio < 0.5:
return "ultra_tall"
else:
return "ultra_wide"
def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
"""Check if two shapes are within tolerance"""
if len(shape1) != len(shape2):
return False
size1 = np.prod(shape1)
size2 = np.prod(shape2)
ratio = abs(size1 - size2) / max(size1, size2)
return ratio <= self.size_tolerance
def __len__(self):
return len(self.samples)
class CDCPreprocessor:
"""
High-level CDC preprocessing coordinator
Handles variable-size latents by batching and delegating to CarreDuChampComputer
"""
def __init__(
self,
k_neighbors: int = 256,
k_bandwidth: int = 8,
d_cdc: int = 8,
gamma: float = 1.0,
device: str = 'cuda',
size_tolerance: float = 0.0
):
self.computer = CarreDuChampComputer(
k_neighbors=k_neighbors,
k_bandwidth=k_bandwidth,
d_cdc=d_cdc,
gamma=gamma,
device=device
)
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
def add_latent(
self,
latent: Union[np.ndarray, torch.Tensor],
global_idx: int,
shape: Optional[Tuple[int, ...]] = None,
metadata: Optional[Dict] = None
):
"""
Add a single latent to the preprocessing queue
Args:
latent: Latent vector (will be flattened)
global_idx: Global dataset index
shape: Original shape (C, H, W)
metadata: Optional metadata
"""
self.batcher.add_latent(latent, global_idx, shape, metadata)
def compute_all(self, save_path: Union[str, Path]) -> Path:
"""
Compute Γ_b for all added latents and save to safetensors
Args:
save_path: Path to save the results
Returns:
Path to saved file
"""
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
# Get batches by exact size (no resizing)
batches = self.batcher.get_batches()
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
# Count samples that will get CDC vs fallback
k_neighbors = self.computer.k
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
samples_fallback = len(self.batcher) - samples_with_cdc
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
# Storage for results
all_results = {}
# Process each bucket
for shape, samples in batches.items():
num_samples = len(samples)
print(f"\n{'='*60}")
print(f"Bucket: {shape} ({num_samples} samples)")
print(f"{'='*60}")
# Check if bucket has enough samples for k-NN
if num_samples < k_neighbors:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
continue
# Collect latents (no resizing needed - all same shape)
latents_list = []
global_indices = []
for sample in samples:
global_indices.append(sample.global_idx)
latents_list.append(sample.latent) # Already flattened
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
# Compute CDC for this batch
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
# No resizing needed - eigenvectors are already correct size
print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
# Merge into overall results
all_results.update(batch_results)
# Save to safetensors
print(f"\n{'='*60}")
print("Saving results...")
print(f"{'='*60}")
tensors_dict = {
'metadata/num_samples': torch.tensor([len(all_results)]),
'metadata/k_neighbors': torch.tensor([self.computer.k]),
'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
'metadata/gamma': torch.tensor([self.computer.gamma]),
}
# Add shape information for each sample
for sample in self.batcher.samples:
idx = sample.global_idx
tensors_dict[f'shapes/{idx}'] = torch.tensor(sample.shape)
# Add CDC results (convert numpy to torch tensors)
for global_idx, (eigvecs, eigvals) in all_results.items():
# Convert numpy arrays to torch tensors
if isinstance(eigvecs, np.ndarray):
eigvecs = torch.from_numpy(eigvecs)
if isinstance(eigvals, np.ndarray):
eigvals = torch.from_numpy(eigvals)
tensors_dict[f'eigenvectors/{global_idx}'] = eigvecs
tensors_dict[f'eigenvalues/{global_idx}'] = eigvals
save_file(tensors_dict, save_path)
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
print(f"\nSaved to {save_path}")
print(f"File size: {file_size_gb:.2f} GB")
return save_path
class GammaBDataset:
"""
Efficient loader for Γ_b matrices during training
Handles variable-size latents
"""
def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.gamma_b_path = Path(gamma_b_path)
# Load metadata
print(f"Loading Γ_b from {gamma_b_path}...")
from safetensors import safe_open
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
self.num_samples = int(f.get_tensor('metadata/num_samples').item())
self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())
print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
@torch.no_grad()
def get_gamma_b_sqrt(
self,
indices: Union[List[int], np.ndarray, torch.Tensor],
device: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get Γ_b^(1/2) components for a batch of indices
Args:
indices: Sample indices
device: Device to load to (defaults to self.device)
Returns:
eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
eigenvalues: (B, d_cdc)
"""
if device is None:
device = self.device
# Convert indices to list
if isinstance(indices, torch.Tensor):
indices = indices.cpu().numpy().tolist()
elif isinstance(indices, np.ndarray):
indices = indices.tolist()
# Load from safetensors
from safetensors import safe_open
eigenvectors_list = []
eigenvalues_list = []
with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
for idx in indices:
idx = int(idx)
eigvecs = f.get_tensor(f'eigenvectors/{idx}').float()
eigvals = f.get_tensor(f'eigenvalues/{idx}').float()
eigenvectors_list.append(eigvecs)
eigenvalues_list.append(eigvals)
# Stack - all should have same d_cdc and d within a batch (enforced by bucketing)
# Check if all eigenvectors have the same dimension
dims = [ev.shape[1] for ev in eigenvectors_list]
if len(set(dims)) > 1:
# Dimension mismatch! This shouldn't happen with proper bucketing
# but can occur if batch contains mixed sizes
raise RuntimeError(
f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
f"Batch indices: {indices}. "
f"This means the training batch contains images of different sizes, "
f"which violates CDC's requirement for uniform latent dimensions per batch. "
f"Check that your dataloader buckets are configured correctly."
)
eigenvectors = torch.stack(eigenvectors_list, dim=0)
eigenvalues = torch.stack(eigenvalues_list, dim=0)
return eigenvectors, eigenvalues
def get_shape(self, idx: int) -> Tuple[int, ...]:
"""Get the original shape for a sample"""
from safetensors import safe_open
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
shape_tensor = f.get_tensor(f'shapes/{idx}')
return tuple(shape_tensor.numpy().tolist())
@torch.no_grad()
def compute_sigma_t_x(
self,
eigenvectors: torch.Tensor,
eigenvalues: torch.Tensor,
x: torch.Tensor,
t: Union[float, torch.Tensor]
) -> torch.Tensor:
"""
Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2)
Args:
eigenvectors: (B, d_cdc, d)
eigenvalues: (B, d_cdc)
x: (B, d) or (B, C, H, W) - will be flattened if needed
t: (B,) or scalar time
Returns:
result: Same shape as input x
"""
# Store original shape to restore later
orig_shape = x.shape
# Flatten x if it's 4D
if x.dim() == 4:
B, C, H, W = x.shape
x = x.reshape(B, -1) # (B, C*H*W)
if not isinstance(t, torch.Tensor):
t = torch.tensor(t, device=x.device, dtype=x.dtype)
if t.dim() == 0:
t = t.expand(x.shape[0])
t = t.view(-1, 1)
# Early return for t=0 to avoid numerical errors
if torch.allclose(t, torch.zeros_like(t), atol=1e-8):
return x.reshape(orig_shape)
# Check if CDC is disabled (all eigenvalues are zero)
# This happens for buckets with < k_neighbors samples
if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8):
# Fallback to standard Gaussian noise (no CDC correction)
return x.reshape(orig_shape)
# Γ_b^(1/2) @ x using low-rank representation
Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
# Σ_t @ x
result = (1 - t) * x + t * gamma_sqrt_x
# Restore original shape
result = result.reshape(orig_shape)
return result

View File

@@ -2,10 +2,8 @@ import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
from accelerate import Accelerator, PartialState
@@ -183,7 +181,7 @@ def sample_image_inference(
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info("negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
@@ -469,8 +467,16 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, batch_indices=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get noisy model input and timesteps for training.
Args:
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
@@ -514,6 +520,44 @@ def get_noisy_model_input_and_timesteps(
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and batch_indices is not None:
# Normalize timesteps to [0, 1] for CDC-FM
t_normalized = timesteps / num_timesteps
# Process each sample individually to handle potential dimension mismatches
# (can happen with multi-subset training where bucketing differs between preprocessing and training)
B, C, H, W = noise.shape
noise_transformed = []
for i in range(B):
idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
# Get cached shape for this sample
cached_shape = gamma_b_dataset.get_shape(idx)
current_shape = (C, H, W)
if cached_shape != current_shape:
# Shape mismatch - sample was bucketed differently between preprocessing and training
# Use standard Gaussian noise for this sample (no CDC)
logger.warning(
f"CDC shape mismatch for sample {idx}: "
f"cached {cached_shape} vs current {current_shape}. "
f"Using Gaussian noise (no CDC)."
)
noise_transformed.append(noise[i])
else:
# Shapes match - apply CDC transformation
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
noise_flat = noise[i].reshape(1, -1)
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
noise = torch.stack(noise_transformed, dim=0)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:

View File

@@ -1569,11 +1569,19 @@ class BaseDataset(torch.utils.data.Dataset):
flippeds = [] # 変数名が微妙
text_encoder_outputs_list = []
custom_attributes = []
indices = [] # CDC-FM: track global dataset indices
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
# CDC-FM: Get global index for this image
# Create a sorted list of keys to ensure deterministic indexing
if not hasattr(self, '_image_key_to_index'):
self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))}
global_idx = self._image_key_to_index[image_key]
indices.append(global_idx)
custom_attributes.append(subset.custom_attributes)
# in case of fine tuning, is_reg is always False
@@ -1819,6 +1827,9 @@ class BaseDataset(torch.utils.data.Dataset):
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
# CDC-FM: Add global indices to batch
example["indices"] = torch.LongTensor(indices)
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -2690,6 +2701,127 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def cache_cdc_gamma_b(
self,
cdc_output_path: str,
k_neighbors: int = 256,
k_bandwidth: int = 8,
d_cdc: int = 8,
gamma: float = 1.0,
force_recache: bool = False,
accelerator: Optional["Accelerator"] = None,
) -> str:
"""
Cache CDC Γ_b matrices for all latents in the dataset
Args:
cdc_output_path: Path to save cdc_gamma_b.safetensors
k_neighbors: k-NN neighbors
k_bandwidth: Bandwidth estimation neighbors
d_cdc: CDC subspace dimension
gamma: CDC strength
force_recache: Force recompute even if cache exists
accelerator: For multi-GPU support
Returns:
Path to cached CDC file
"""
from pathlib import Path
cdc_path = Path(cdc_output_path)
# Check if valid cache exists
if cdc_path.exists() and not force_recache:
if self._is_cdc_cache_valid(cdc_path, k_neighbors, d_cdc, gamma):
logger.info(f"Valid CDC cache found at {cdc_path}, skipping preprocessing")
return str(cdc_path)
else:
logger.info(f"CDC cache found but invalid, will recompute")
# Only main process computes CDC
is_main = accelerator is None or accelerator.is_main_process
if not is_main:
if accelerator is not None:
accelerator.wait_for_everyone()
return str(cdc_path)
logger.info("=" * 60)
logger.info("Starting CDC-FM preprocessing")
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
logger.info("=" * 60)
# Initialize CDC preprocessor
from library.cdc_fm import CDCPreprocessor
preprocessor = CDCPreprocessor(
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Get caching strategy for loading latents
from library.strategy_base import LatentsCachingStrategy
caching_strategy = LatentsCachingStrategy.get_strategy()
# Collect all latents from all datasets
for dataset_idx, dataset in enumerate(self.datasets):
logger.info(f"Loading latents from dataset {dataset_idx}...")
image_infos = list(dataset.image_data.values())
for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset {dataset_idx}")):
# Load latent from disk or memory
if info.latents is not None:
latent = info.latents
elif info.latents_npz is not None:
# Load from disk
latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso)
if latent is None:
logger.warning(f"Failed to load latent from {info.latents_npz}, skipping")
continue
else:
logger.warning(f"No latent found for {info.absolute_path}, skipping")
continue
# Add to preprocessor (with unique global index across all datasets)
actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx
preprocessor.add_latent(latent=latent, global_idx=actual_global_idx, shape=latent.shape, metadata={"image_key": info.image_key})
# Compute and save
logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
preprocessor.compute_all(save_path=cdc_path)
if accelerator is not None:
accelerator.wait_for_everyone()
return str(cdc_path)
def _is_cdc_cache_valid(self, cdc_path: "pathlib.Path", k_neighbors: int, d_cdc: int, gamma: float) -> bool:
"""Check if CDC cache has matching hyperparameters"""
try:
from safetensors import safe_open
with safe_open(str(cdc_path), framework="pt", device="cpu") as f:
cached_k = int(f.get_tensor("metadata/k_neighbors").item())
cached_d = int(f.get_tensor("metadata/d_cdc").item())
cached_gamma = float(f.get_tensor("metadata/gamma").item())
cached_num = int(f.get_tensor("metadata/num_samples").item())
expected_num = sum(len(d.image_data) for d in self.datasets)
valid = cached_k == k_neighbors and cached_d == d_cdc and abs(cached_gamma - gamma) < 1e-6 and cached_num == expected_num
if not valid:
logger.info(
f"Cache mismatch: k={cached_k} (expected {k_neighbors}), "
f"d_cdc={cached_d} (expected {d_cdc}), "
f"gamma={cached_gamma} (expected {gamma}), "
f"num={cached_num} (expected {expected_num})"
)
return valid
except Exception as e:
logger.warning(f"Error validating CDC cache: {e}")
return False
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)

View File

@@ -0,0 +1,242 @@
"""
Tests to verify CDC eigenvalue scaling is correct.
These tests ensure eigenvalues are properly scaled to prevent training loss explosion.
"""
import numpy as np
import pytest
import torch
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor
class TestEigenvalueScaling:
"""Test that eigenvalues are properly scaled to reasonable ranges"""
def test_eigenvalues_in_correct_range(self, tmp_path):
"""Verify eigenvalues are scaled to ~0.01-1.0 range, not millions"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
# Add deterministic latents with structured patterns
for i in range(10):
# Create gradient pattern: values from 0 to 2.0 across spatial dims
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
for h in range(8):
for w in range(8):
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
# Add per-sample variation
latent = latent + i * 0.1
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify eigenvalues are in correct range
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(10):
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
# Filter out zero eigenvalues (from padding when k < d_cdc)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# Critical assertions for eigenvalue scale
assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
# Check sqrt (used in noise) is reasonable
sqrt_max = np.sqrt(all_eigvals.max())
assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
print(f"✓ sqrt(max): {sqrt_max:.4f}")
def test_eigenvalues_not_all_zero(self, tmp_path):
"""Ensure eigenvalues are not all zero (indicating computation failure)"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
for i in range(10):
# Create deterministic pattern
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
with safe_open(str(result_path), framework="pt", device="cpu") as f:
all_eigvals = []
for i in range(10):
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
all_eigvals.extend(eigvals)
all_eigvals = np.array(all_eigvals)
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
# With clamping, eigenvalues will be in range [1e-3, gamma*1.0]
# Check that we have some non-zero eigenvalues
assert len(non_zero_eigvals) > 0, "All eigenvalues are zero - computation failed"
# Check they're in the expected clamped range
assert np.all(non_zero_eigvals >= 1e-3), f"Some eigenvalues below clamp min: {np.min(non_zero_eigvals)}"
assert np.all(non_zero_eigvals <= 1.0), f"Some eigenvalues above clamp max: {np.max(non_zero_eigvals)}"
print(f"\n✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
print(f"✓ Range: [{np.min(non_zero_eigvals):.4f}, {np.max(non_zero_eigvals):.4f}]")
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
def test_fp16_storage_no_overflow(self, tmp_path):
"""Verify fp16 storage doesn't overflow (max fp16 = 65,504)"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
for i in range(10):
# Create deterministic pattern with higher magnitude
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
for h in range(8):
for w in range(8):
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
latent = latent + i * 0.3
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
with safe_open(str(result_path), framework="pt", device="cpu") as f:
# Check dtype is fp16
eigvecs = f.get_tensor("eigenvectors/0")
eigvals = f.get_tensor("eigenvalues/0")
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
# Check no values near fp16 max (would indicate overflow)
FP16_MAX = 65504
max_eigval = eigvals.max().item()
assert max_eigval < 100, (
f"Eigenvalue {max_eigval:.2e} is suspiciously large for fp16 storage. "
f"May indicate overflow (fp16 max = {FP16_MAX})"
)
print(f"\n✓ Storage dtype: {eigvals.dtype}")
print(f"✓ Max eigenvalue: {max_eigval:.4f} (safe for fp16)")
def test_latent_magnitude_preserved(self, tmp_path):
"""Verify latent magnitude is preserved (no unwanted normalization)"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
# Store original latents with deterministic patterns
original_latents = []
for i in range(10):
# Create structured pattern with known magnitude
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
original_latents.append(latent.clone())
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
# Compute original latent statistics
orig_std = torch.stack(original_latents).std().item()
output_path = tmp_path / "test_gamma_b.safetensors"
preprocessor.compute_all(save_path=output_path)
# The stored latents should preserve original magnitude
stored_latents_std = np.std([s.latent for s in preprocessor.batcher.samples])
# Should be similar to original (within 20% due to potential batching effects)
assert 0.8 * orig_std < stored_latents_std < 1.2 * orig_std, (
f"Stored latent std {stored_latents_std:.2f} differs too much from "
f"original {orig_std:.2f}. Latent magnitude was not preserved."
)
print(f"\n✓ Original latent std: {orig_std:.2f}")
print(f"✓ Stored latent std: {stored_latents_std:.2f}")
class TestTrainingLossScale:
"""Test that eigenvalues produce reasonable loss magnitudes"""
def test_noise_magnitude_reasonable(self, tmp_path):
"""Verify CDC noise has reasonable magnitude for training"""
from library.cdc_fm import GammaBDataset
# Create CDC cache with deterministic data
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
for i in range(10):
# Create deterministic pattern
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
for c in range(16):
for h in range(4):
for w in range(4):
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
output_path = tmp_path / "test_gamma_b.safetensors"
cdc_path = preprocessor.compute_all(save_path=output_path)
# Load and compute noise
gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
# Simulate training scenario with deterministic data
batch_size = 3
latents = torch.zeros(batch_size, 16, 4, 4)
for b in range(batch_size):
for c in range(16):
for h in range(4):
for w in range(4):
latents[b, c, h, w] = (b + c + h + w) / 24.0
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
indices = [0, 5, 9]
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices)
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
# Check noise magnitude
noise_std = noise.std().item()
latent_std = latents.std().item()
# Noise should be similar magnitude to input latents (within 10x)
ratio = noise_std / latent_std
assert 0.1 < ratio < 10.0, (
f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
f"ratio {ratio:.2f} is too extreme. Will cause training instability."
)
# Simulated MSE loss should be reasonable
simulated_loss = torch.mean((noise - latents) ** 2).item()
assert simulated_loss < 100.0, (
f"Simulated MSE loss {simulated_loss:.2f} is too high. "
f"Should be O(0.1-1.0) for stable training."
)
print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,164 @@
"""
Test comparing interpolation vs pad/truncate for CDC preprocessing.
This test quantifies the difference between the two approaches.
"""
import numpy as np
import pytest
import torch
import torch.nn.functional as F
class TestInterpolationComparison:
"""Compare interpolation vs pad/truncate"""
def test_intermediate_representation_quality(self):
"""Compare intermediate representation quality for CDC computation"""
# Create test latents with different sizes - deterministic
latent_small = torch.zeros(16, 4, 4)
for c in range(16):
for h in range(4):
for w in range(4):
latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
latent_large = torch.zeros(16, 8, 8)
for c in range(16):
for h in range(8):
for w in range(8):
latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
target_h, target_w = 6, 6 # Median size
# Method 1: Interpolation
def interpolate_method(latent, target_h, target_w):
latent_input = latent.unsqueeze(0) # (1, C, H, W)
latent_resized = F.interpolate(
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
)
# Resize back
C, H, W = latent.shape
latent_reconstructed = F.interpolate(
latent_resized, size=(H, W), mode='bilinear', align_corners=False
)
error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
return relative_error
# Method 2: Pad/Truncate
def pad_truncate_method(latent, target_h, target_w):
C, H, W = latent.shape
latent_flat = latent.reshape(-1)
target_dim = C * target_h * target_w
current_dim = C * H * W
if current_dim == target_dim:
latent_resized_flat = latent_flat
elif current_dim > target_dim:
# Truncate
latent_resized_flat = latent_flat[:target_dim]
else:
# Pad
latent_resized_flat = torch.zeros(target_dim)
latent_resized_flat[:current_dim] = latent_flat
# Resize back
if current_dim == target_dim:
latent_reconstructed_flat = latent_resized_flat
elif current_dim > target_dim:
# Pad back
latent_reconstructed_flat = torch.zeros(current_dim)
latent_reconstructed_flat[:target_dim] = latent_resized_flat
else:
# Truncate back
latent_reconstructed_flat = latent_resized_flat[:current_dim]
latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
return relative_error
# Compare for small latent (needs padding)
interp_error_small = interpolate_method(latent_small, target_h, target_w)
pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
# Compare for large latent (needs truncation)
interp_error_large = interpolate_method(latent_large, target_h, target_w)
truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
print("\n" + "=" * 60)
print("Reconstruction Error Comparison")
print("=" * 60)
print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
print(f" Interpolation error: {interp_error_small:.6f}")
print(f" Pad/truncate error: {pad_error_small:.6f}")
if pad_error_small > 0:
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
else:
print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
print(f" BUT the intermediate representation is corrupted with zeros!")
print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
print(f" Interpolation error: {interp_error_large:.6f}")
print(f" Pad/truncate error: {truncate_error_large:.6f}")
if truncate_error_large > 0:
print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
# The key insight: Reconstruction error is NOT what matters for CDC!
# What matters is the INTERMEDIATE representation quality used for geometry estimation.
# Pad/truncate may have good reconstruction, but the intermediate is corrupted.
print("\nKey insight: For CDC, intermediate representation quality matters,")
print("not reconstruction error. Interpolation preserves spatial structure.")
# Verify interpolation errors are reasonable
assert interp_error_small < 1.0, "Interpolation should have reasonable error"
assert interp_error_large < 1.0, "Interpolation should have reasonable error"
def test_spatial_structure_preservation(self):
"""Test that interpolation preserves spatial structure better than pad/truncate"""
# Create a latent with clear spatial pattern (gradient)
C, H, W = 16, 4, 4
latent = torch.zeros(C, H, W)
for i in range(H):
for j in range(W):
latent[:, i, j] = i * W + j # Gradient pattern
target_h, target_w = 6, 6
# Interpolation
latent_input = latent.unsqueeze(0)
latent_interp = F.interpolate(
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
).squeeze(0)
# Pad/truncate
latent_flat = latent.reshape(-1)
target_dim = C * target_h * target_w
latent_padded = torch.zeros(target_dim)
latent_padded[:len(latent_flat)] = latent_flat
latent_pad = latent_padded.reshape(C, target_h, target_w)
# Check gradient preservation
# For interpolation, adjacent pixels should have smooth gradients
grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
# For padding, there will be abrupt changes (gradient to zero)
grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
print("\n" + "=" * 60)
print("Spatial Structure Preservation")
print("=" * 60)
print(f"\nGradient smoothness (lower is smoother):")
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
# Padding introduces larger gradients due to abrupt zeros
assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,232 @@
"""
Standalone tests for CDC-FM integration.
These tests focus on CDC-FM specific functionality without importing
the full training infrastructure that has problematic dependencies.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
import torch
from safetensors.torch import save_file
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPreprocessor:
"""Test CDC preprocessing functionality"""
def test_cdc_preprocessor_basic_workflow(self, tmp_path):
"""Test basic CDC preprocessing with small dataset"""
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
# Add 10 small latents
for i in range(10):
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
# Compute and save
output_path = tmp_path / "test_gamma_b.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify file was created
assert Path(result_path).exists()
# Verify structure
from safetensors import safe_open
with safe_open(str(result_path), framework="pt", device="cpu") as f:
assert f.get_tensor("metadata/num_samples").item() == 10
assert f.get_tensor("metadata/k_neighbors").item() == 5
assert f.get_tensor("metadata/d_cdc").item() == 4
# Check first sample
eigvecs = f.get_tensor("eigenvectors/0")
eigvals = f.get_tensor("eigenvalues/0")
assert eigvecs.shape[0] == 4 # d_cdc
assert eigvals.shape[0] == 4 # d_cdc
def test_cdc_preprocessor_different_shapes(self, tmp_path):
"""Test CDC preprocessing with variable-size latents (bucketing)"""
preprocessor = CDCPreprocessor(
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
)
# Add 5 latents of shape (16, 4, 4)
for i in range(5):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
# Add 5 latents of different shape (16, 8, 8)
for i in range(5, 10):
latent = torch.randn(16, 8, 8, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
# Compute and save
output_path = tmp_path / "test_gamma_b_multi.safetensors"
result_path = preprocessor.compute_all(save_path=output_path)
# Verify both shape groups were processed
from safetensors import safe_open
with safe_open(str(result_path), framework="pt", device="cpu") as f:
# Check shapes are stored
shape_0 = f.get_tensor("shapes/0")
shape_5 = f.get_tensor("shapes/5")
assert tuple(shape_0.tolist()) == (16, 4, 4)
assert tuple(shape_5.tolist()) == (16, 8, 8)
class TestGammaBDataset:
"""Test GammaBDataset loading and retrieval"""
@pytest.fixture
def sample_cdc_cache(self, tmp_path):
"""Create a sample CDC cache file for testing"""
cache_path = tmp_path / "test_gamma_b.safetensors"
# Create mock Γ_b data for 5 samples
tensors = {
"metadata/num_samples": torch.tensor([5]),
"metadata/k_neighbors": torch.tensor([10]),
"metadata/d_cdc": torch.tensor([4]),
"metadata/gamma": torch.tensor([1.0]),
}
# Add shape and CDC data for each sample
for i in range(5):
tensors[f"shapes/{i}"] = torch.tensor([16, 8, 8]) # C, H, W
tensors[f"eigenvectors/{i}"] = torch.randn(4, 1024, dtype=torch.float32) # d_cdc x d
tensors[f"eigenvalues/{i}"] = torch.rand(4, dtype=torch.float32) + 0.1 # positive
save_file(tensors, str(cache_path))
return cache_path
def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
"""Test that GammaBDataset loads metadata correctly"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
assert gamma_b_dataset.num_samples == 5
assert gamma_b_dataset.d_cdc == 4
def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
"""Test retrieving Γ_b^(1/2) components"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
# Get Γ_b for indices [0, 2, 4]
indices = [0, 2, 4]
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(indices, device="cpu")
# Check shapes
assert eigenvectors.shape == (3, 4, 1024) # (batch, d_cdc, d)
assert eigenvalues.shape == (3, 4) # (batch, d_cdc)
# Check values are positive
assert torch.all(eigenvalues > 0)
def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
"""Test compute_sigma_t_x returns x unchanged at t=0"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
# Create test latents (batch of 3, matching d=1024 flattened)
x = torch.randn(3, 1024) # B, d (flattened)
t = torch.zeros(3) # t = 0 for all samples
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 1, 2], device="cpu")
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
# At t=0, should return x unchanged
assert torch.allclose(sigma_t_x, x, atol=1e-6)
def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
"""Test compute_sigma_t_x returns correct shape"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
x = torch.randn(2, 1024) # B, d (flattened)
t = torch.tensor([0.3, 0.7])
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([1, 3], device="cpu")
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
# Should return same shape as input
assert sigma_t_x.shape == x.shape
def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
"""Test compute_sigma_t_x produces finite values"""
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
x = torch.randn(3, 1024) # B, d (flattened)
t = torch.rand(3) # Random timesteps in [0, 1]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 2, 4], device="cpu")
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
# Should not contain NaNs or Infs
assert not torch.isnan(sigma_t_x).any()
assert torch.isfinite(sigma_t_x).all()
class TestCDCEndToEnd:
"""End-to-end CDC workflow tests"""
def test_full_preprocessing_and_usage_workflow(self, tmp_path):
"""Test complete workflow: preprocess -> save -> load -> use"""
# Step 1: Preprocess latents
preprocessor = CDCPreprocessor(
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
)
num_samples = 10
for i in range(num_samples):
latent = torch.randn(16, 4, 4, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
output_path = tmp_path / "cdc_gamma_b.safetensors"
cdc_path = preprocessor.compute_all(save_path=output_path)
# Step 2: Load with GammaBDataset
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
assert gamma_b_dataset.num_samples == num_samples
# Step 3: Use in mock training scenario
batch_size = 3
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
batch_t = torch.rand(batch_size)
batch_indices = [0, 5, 9]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu")
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
# Verify output is reasonable
assert sigma_t_x.shape == batch_latents_flat.shape
assert not torch.isnan(sigma_t_x).any()
assert torch.isfinite(sigma_t_x).all()
# Verify that noise changes with different timesteps
sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size))
sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size))
# At t=0, should be close to x; at t=1, should be different
assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6)
assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -622,6 +622,23 @@ class NetworkTrainer:
accelerator.wait_for_everyone()
# CDC-FM preprocessing
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
cdc_output_path = os.path.join(args.output_dir, "cdc_gamma_b.safetensors")
self.cdc_cache_path = train_dataset_group.cache_cdc_gamma_b(
cdc_output_path=cdc_output_path,
k_neighbors=args.cdc_k_neighbors,
k_bandwidth=args.cdc_k_bandwidth,
d_cdc=args.cdc_d_cdc,
gamma=args.cdc_gamma,
force_recache=args.force_recache_cdc,
accelerator=accelerator,
)
else:
self.cdc_cache_path = None
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
text_encoding_strategy = self.get_text_encoding_strategy(args)
@@ -634,7 +651,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
if unet is None:
if unet is none:
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
@@ -643,10 +660,10 @@ class NetworkTrainer:
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
if args.base_weights is not none:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
if args.base_weights_multiplier is none or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
@@ -660,6 +677,17 @@ class NetworkTrainer:
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# Load CDC-FM Γ_b dataset if enabled
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
from library.cdc_fm import GammaBDataset
logger.info(f"Loading CDC Γ_b dataset from {self.cdc_cache_path}")
self.gamma_b_dataset = GammaBDataset(
gamma_b_path=self.cdc_cache_path, device="cuda" if torch.cuda.is_available() else "cpu"
)
else:
self.gamma_b_dataset = None
# prepare network
net_kwargs = {}
if args.network_args is not None: