From f552f9a3bdfe01281f9acc5f134cc513f2fbdb14 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 9 Oct 2025 15:18:43 -0400 Subject: [PATCH] =?UTF-8?q?Add=20CDC-FM=20(Carr=C3=A9=20du=20Champ=20Flow?= =?UTF-8?q?=20Matching)=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements geometry-aware noise generation for FLUX training based on arXiv:2510.05930v1. --- flux_train_network.py | 58 +- library/cdc_fm.py | 712 ++++++++++++++++++ library/flux_train_utils.py | 54 +- library/train_util.py | 132 ++++ tests/library/test_cdc_eigenvalue_scaling.py | 242 ++++++ .../test_cdc_interpolation_comparison.py | 164 ++++ tests/library/test_cdc_standalone.py | 232 ++++++ train_network.py | 34 +- 8 files changed, 1615 insertions(+), 13 deletions(-) create mode 100644 library/cdc_fm.py create mode 100644 tests/library/test_cdc_eigenvalue_scaling.py create mode 100644 tests/library/test_cdc_interpolation_comparison.py create mode 100644 tests/library/test_cdc_standalone.py diff --git a/flux_train_network.py b/flux_train_network.py index cfc61708..48c0fbc9 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -1,7 +1,5 @@ import argparse import copy -import math -import random from typing import Any, Optional, Union import torch @@ -36,6 +34,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False self.model_type: Optional[str] = None + self.gamma_b_dataset = None # CDC-FM Γ_b dataset def assert_extra_args( self, @@ -327,9 +326,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): noise = torch.randn_like(latents) bsz = latents.shape[0] - # get noisy model input and timesteps + # Get CDC parameters if enabled + gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "indices" in batch) else None + batch_indices = batch.get("indices") if gamma_b_dataset is not None else None + + # Get noisy model input and 
timesteps + # If CDC is enabled, this will transform the noise with geometry-aware covariance noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype, + gamma_b_dataset=gamma_b_dataset, batch_indices=batch_indices ) # pack latents and get img_ids @@ -494,7 +499,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): module.forward = forward_hook(module) if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: - logger.info(f"T5XXL already prepared for fp8") + logger.info("T5XXL already prepared for fp8") else: logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") text_encoder.to(te_weight_dtype) # fp8 @@ -533,6 +538,49 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." 
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", ) + + # CDC-FM arguments + parser.add_argument( + "--use_cdc_fm", + action="store_true", + help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training" + " / CDC-FM(Carré du Champ Flow Matching)を有効にして幾何学的ノイズを使用", + ) + parser.add_argument( + "--cdc_k_neighbors", + type=int, + default=256, + help="Number of neighbors for k-NN graph in CDC-FM (default: 256)" + " / CDC-FMのk-NNグラフの近傍数(デフォルト: 256)", + ) + parser.add_argument( + "--cdc_k_bandwidth", + type=int, + default=8, + help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)" + " / CDC-FMの帯域幅推定の近傍数(デフォルト: 8)", + ) + parser.add_argument( + "--cdc_d_cdc", + type=int, + default=8, + help="Dimension of CDC subspace (default: 8)" + " / CDCサブ空間の次元(デフォルト: 8)", + ) + parser.add_argument( + "--cdc_gamma", + type=float, + default=1.0, + help="CDC strength parameter (default: 1.0)" + " / CDC強度パラメータ(デフォルト: 1.0)", + ) + parser.add_argument( + "--force_recache_cdc", + action="store_true", + help="Force recompute CDC cache even if valid cache exists" + " / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算", + ) + return parser diff --git a/library/cdc_fm.py b/library/cdc_fm.py new file mode 100644 index 00000000..ca9f6e81 --- /dev/null +++ b/library/cdc_fm.py @@ -0,0 +1,712 @@ +import logging +import torch +import numpy as np +import faiss # type: ignore +from pathlib import Path +from tqdm import tqdm +from safetensors.torch import save_file +from typing import List, Dict, Optional, Union, Tuple +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class LatentSample: + """ + Container for a single latent with metadata + """ + latent: np.ndarray # (d,) flattened latent vector + global_idx: int # Global index in dataset + shape: Tuple[int, ...] # Original shape before flattening (C, H, W) + metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.) 
+ + +class CarreDuChampComputer: + """ + Core CDC-FM computation - agnostic to data source + Just handles the math for a batch of same-size latents + """ + + def __init__( + self, + k_neighbors: int = 256, + k_bandwidth: int = 8, + d_cdc: int = 8, + gamma: float = 1.0, + device: str = 'cuda' + ): + self.k = k_neighbors + self.k_bw = k_bandwidth + self.d_cdc = d_cdc + self.gamma = gamma + self.device = torch.device(device if torch.cuda.is_available() else 'cpu') + + def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Build k-NN graph using FAISS + + Args: + latents_np: (N, d) numpy array of same-dimensional latents + + Returns: + distances: (N, k_actual+1) distances (k_actual may be less than k if N is small) + indices: (N, k_actual+1) neighbor indices + """ + N, d = latents_np.shape + + # Clamp k to available neighbors (can't have more neighbors than samples) + k_actual = min(self.k, N - 1) + + # Ensure float32 + if latents_np.dtype != np.float32: + latents_np = latents_np.astype(np.float32) + + # Build FAISS index + index = faiss.IndexFlatL2(d) + + if torch.cuda.is_available(): + res = faiss.StandardGpuResources() + index = faiss.index_cpu_to_gpu(res, 0, index) + + index.add(latents_np) # type: ignore + distances, indices = index.search(latents_np, k_actual + 1) # type: ignore + + return distances, indices + + @torch.no_grad() + def compute_gamma_b_single( + self, + point_idx: int, + latents_np: np.ndarray, + distances: np.ndarray, + indices: np.ndarray, + epsilon: np.ndarray + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute Γ_b for a single point + + Args: + point_idx: Index of point to process + latents_np: (N, d) all latents in this batch + distances: (N, k+1) precomputed distances + indices: (N, k+1) precomputed neighbor indices + epsilon: (N,) bandwidth per point + + Returns: + eigenvectors: (d_cdc, d) as half precision tensor + eigenvalues: (d_cdc,) as half precision tensor + """ + d = latents_np.shape[1] + + 
# Get neighbors (exclude self) + neighbor_idx = indices[point_idx, 1:] # (k,) + neighbor_points = latents_np[neighbor_idx] # (k, d) + + # Clamp distances to prevent overflow (max realistic L2 distance) + MAX_DISTANCE = 1e10 + neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE) + neighbor_dists_sq = neighbor_dists ** 2 # (k,) + + # Compute Gaussian kernel weights with numerical guards + eps_i = max(epsilon[point_idx], 1e-10) # Prevent division by zero + eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10) + + # Compute denominator with guard against overflow + denom = eps_i * eps_neighbors + denom = np.maximum(denom, 1e-20) # Additional guard + + # Compute weights with safe exponential + exp_arg = -neighbor_dists_sq / denom + exp_arg = np.clip(exp_arg, -50, 0) # Prevent exp overflow/underflow + weights = np.exp(exp_arg) + + # Normalize weights, handle edge case of all zeros + weight_sum = weights.sum() + if weight_sum < 1e-20 or not np.isfinite(weight_sum): + # Fallback to uniform weights + weights = np.ones_like(weights) / len(weights) + else: + weights = weights / weight_sum + + # Compute local mean + m_star = np.sum(weights[:, None] * neighbor_points, axis=0) + + # Center and weight for SVD + centered = neighbor_points - m_star + weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d) + + # Validate input is finite before SVD + if not np.all(np.isfinite(weighted_centered)): + logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback") + # Fallback: use uniform weights and simple centering + weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points) + m_star = np.mean(neighbor_points, axis=0) + centered = neighbor_points - m_star + weighted_centered = np.sqrt(weights_uniform)[:, None] * centered + + # Move to GPU for SVD (100x speedup!) 
+ weighted_centered_torch = torch.from_numpy(weighted_centered).to( + self.device, dtype=torch.float32 + ) + + try: + U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False) + except RuntimeError as e: + logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}") + try: + U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False) + U = torch.from_numpy(U).to(self.device) + S = torch.from_numpy(S).to(self.device) + Vh = torch.from_numpy(Vh).to(self.device) + except np.linalg.LinAlgError as e2: + logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix") + # Return zero eigenvalues/vectors as fallback + return ( + torch.zeros(self.d_cdc, d, dtype=torch.float16), + torch.zeros(self.d_cdc, dtype=torch.float16) + ) + + # Eigenvalues of Γ_b + eigenvalues_full = S ** 2 + + # Keep top d_cdc + if len(eigenvalues_full) >= self.d_cdc: + top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc) + top_eigenvectors = Vh[top_idx, :] # (d_cdc, d) + else: + # Pad if k < d_cdc + top_eigenvalues = eigenvalues_full + top_eigenvectors = Vh + if len(eigenvalues_full) < self.d_cdc: + pad_size = self.d_cdc - len(eigenvalues_full) + top_eigenvalues = torch.cat([ + top_eigenvalues, + torch.zeros(pad_size, device=self.device) + ]) + top_eigenvectors = torch.cat([ + top_eigenvectors, + torch.zeros(pad_size, d, device=self.device) + ]) + + # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33) + # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max) + # Then apply gamma: γc_i Γ̂(x^(i)) + # + # Our implementation: + # 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor + # 2. Apply gamma hyperparameter - aligns with paper's γ global scaling + # 3. 
Clamp for numerical stability + # + # Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents) + # Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound + + # Step 1: Normalize by the maximum eigenvalue to get relative scales + # This is the paper's 1/λ_1^i normalization factor + max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0 + + if max_eigenval > 1e-10: + # Scale so max eigenvalue = 1.0, preserving relative ratios + top_eigenvalues = top_eigenvalues / max_eigenval + + # Step 2: Apply gamma and clamp to safe range + # Gamma is the paper's tuneable hyperparameter (defaults to 1.0) + # Clamping ensures numerical stability and prevents extreme values + top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0) + + # Convert to fp16 for storage - now safe since eigenvalues are ~0.01-1.0 + # fp16 range: 6e-5 to 65,504, our values are well within this + eigenvectors_fp16 = top_eigenvectors.cpu().half() + eigenvalues_fp16 = top_eigenvalues.cpu().half() + + # Cleanup + del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return eigenvectors_fp16, eigenvalues_fp16 + + def compute_for_batch( + self, + latents_np: np.ndarray, + global_indices: List[int] + ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]: + """ + Compute Γ_b for all points in a batch of same-size latents + + Args: + latents_np: (N, d) numpy array + global_indices: List of global dataset indices for each latent + + Returns: + Dict mapping global_idx -> (eigenvectors, eigenvalues) + """ + N, d = latents_np.shape + + # Validate inputs + if len(global_indices) != N: + raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices") + + print(f"Computing CDC for batch: {N} samples, dim={d}") + + # Handle small sample cases - require minimum samples for meaningful k-NN + 
MIN_SAMPLES_FOR_CDC = 5 # Need at least 5 samples for reasonable geometry estimation + + if N < MIN_SAMPLES_FOR_CDC: + print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)") + results = {} + for local_idx in range(N): + global_idx = global_indices[local_idx] + # Return zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x) + eigvecs = np.zeros((self.d_cdc, d), dtype=np.float16) + eigvals = np.zeros(self.d_cdc, dtype=np.float16) + results[global_idx] = (eigvecs, eigvals) + return results + + # Step 1: Build k-NN graph + print(" Building k-NN graph...") + distances, indices = self.compute_knn_graph(latents_np) + + # Step 2: Compute bandwidth + # Use min to handle case where k_bw >= actual neighbors returned + k_bw_actual = min(self.k_bw, distances.shape[1] - 1) + epsilon = distances[:, k_bw_actual] + + # Step 3: Compute Γ_b for each point + results = {} + print(" Computing Γ_b for each point...") + for local_idx in tqdm(range(N), desc=" Processing", leave=False): + global_idx = global_indices[local_idx] + eigvecs, eigvals = self.compute_gamma_b_single( + local_idx, latents_np, distances, indices, epsilon + ) + results[global_idx] = (eigvecs, eigvals) + + return results + + +class LatentBatcher: + """ + Collects variable-size latents and batches them by size + """ + + def __init__(self, size_tolerance: float = 0.0): + """ + Args: + size_tolerance: If > 0, group latents within tolerance % of size + If 0, only exact size matches are batched + """ + self.size_tolerance = size_tolerance + self.samples: List[LatentSample] = [] + + def add_sample(self, sample: LatentSample): + """Add a single latent sample""" + self.samples.append(sample) + + def add_latent( + self, + latent: Union[np.ndarray, torch.Tensor], + global_idx: int, + shape: Optional[Tuple[int, ...]] = None, + metadata: Optional[Dict] = None + ): + """ + Add a latent vector with automatic shape tracking + + Args: + latent: Latent vector (any shape, 
will be flattened) + global_idx: Global index in dataset + shape: Original shape (if None, uses latent.shape) + metadata: Optional metadata dict + """ + # Convert to numpy and flatten + if isinstance(latent, torch.Tensor): + latent_np = latent.cpu().numpy() + else: + latent_np = latent + + original_shape = shape if shape is not None else latent_np.shape + latent_flat = latent_np.flatten() + + sample = LatentSample( + latent=latent_flat, + global_idx=global_idx, + shape=original_shape, + metadata=metadata + ) + + self.add_sample(sample) + + def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]: + """ + Group samples by exact shape to avoid resizing distortion. + + Each bucket contains only samples with identical latent dimensions. + Buckets with fewer than k_neighbors samples will be skipped during CDC + computation and fall back to standard Gaussian noise. + + Returns: + Dict mapping exact_shape -> list of samples with that shape + """ + batches = {} + + for sample in self.samples: + shape_key = sample.shape + + # Group by exact shape only - no aspect ratio grouping or resizing + if shape_key not in batches: + batches[shape_key] = [] + + batches[shape_key].append(sample) + + return batches + + def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str: + """ + Get aspect ratio category for grouping. + Groups images by aspect ratio bins to ensure sufficient samples. + + For shape (C, H, W), computes aspect ratio H/W and bins it. 
+ """ + if len(shape) < 3: + return "unknown" + + # Extract spatial dimensions (H, W) + h, w = shape[-2], shape[-1] + aspect_ratio = h / w + + # Define aspect ratio bins (±15% tolerance) + # Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16) + bins = [ + (0.5, 0.65, "9:16"), # Portrait tall + (0.65, 0.85, "3:4"), # Portrait + (0.85, 1.15, "1:1"), # Square + (1.15, 1.50, "4:3"), # Landscape + (1.50, 2.0, "16:9"), # Landscape wide + (2.0, 3.0, "21:9"), # Ultra wide + ] + + for min_ratio, max_ratio, label in bins: + if min_ratio <= aspect_ratio < max_ratio: + return label + + # Fallback for extreme ratios + if aspect_ratio < 0.5: + return "ultra_tall" + else: + return "ultra_wide" + + def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool: + """Check if two shapes are within tolerance""" + if len(shape1) != len(shape2): + return False + + size1 = np.prod(shape1) + size2 = np.prod(shape2) + + ratio = abs(size1 - size2) / max(size1, size2) + return ratio <= self.size_tolerance + + def __len__(self): + return len(self.samples) + + +class CDCPreprocessor: + """ + High-level CDC preprocessing coordinator + Handles variable-size latents by batching and delegating to CarreDuChampComputer + """ + + def __init__( + self, + k_neighbors: int = 256, + k_bandwidth: int = 8, + d_cdc: int = 8, + gamma: float = 1.0, + device: str = 'cuda', + size_tolerance: float = 0.0 + ): + self.computer = CarreDuChampComputer( + k_neighbors=k_neighbors, + k_bandwidth=k_bandwidth, + d_cdc=d_cdc, + gamma=gamma, + device=device + ) + self.batcher = LatentBatcher(size_tolerance=size_tolerance) + + def add_latent( + self, + latent: Union[np.ndarray, torch.Tensor], + global_idx: int, + shape: Optional[Tuple[int, ...]] = None, + metadata: Optional[Dict] = None + ): + """ + Add a single latent to the preprocessing queue + + Args: + latent: Latent vector (will be flattened) + global_idx: Global dataset index + shape: Original shape (C, H, W) + 
metadata: Optional metadata + """ + self.batcher.add_latent(latent, global_idx, shape, metadata) + + def compute_all(self, save_path: Union[str, Path]) -> Path: + """ + Compute Γ_b for all added latents and save to safetensors + + Args: + save_path: Path to save the results + + Returns: + Path to saved file + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Get batches by exact size (no resizing) + batches = self.batcher.get_batches() + + print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets") + + # Count samples that will get CDC vs fallback + k_neighbors = self.computer.k + samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors) + samples_fallback = len(self.batcher) - samples_with_cdc + + print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)") + print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)") + + # Storage for results + all_results = {} + + # Process each bucket + for shape, samples in batches.items(): + num_samples = len(samples) + + print(f"\n{'='*60}") + print(f"Bucket: {shape} ({num_samples} samples)") + print(f"{'='*60}") + + # Check if bucket has enough samples for k-NN + if num_samples < k_neighbors: + print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}") + print(" → These samples will use standard Gaussian noise (no CDC)") + + # Store zero eigenvectors/eigenvalues (Gaussian fallback) + C, H, W = shape + d = C * H * W + + for sample in samples: + eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16) + eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16) + all_results[sample.global_idx] = (eigvecs, eigvals) + + continue + + # Collect latents (no resizing needed - all same shape) + latents_list = [] + global_indices = [] + + for sample in 
samples: + global_indices.append(sample.global_idx) + latents_list.append(sample.latent) # Already flattened + + latents_np = np.stack(latents_list, axis=0) # (N, C*H*W) + + # Compute CDC for this batch + print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}") + batch_results = self.computer.compute_for_batch(latents_np, global_indices) + + # No resizing needed - eigenvectors are already correct size + print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)") + + # Merge into overall results + all_results.update(batch_results) + + # Save to safetensors + print(f"\n{'='*60}") + print("Saving results...") + print(f"{'='*60}") + + tensors_dict = { + 'metadata/num_samples': torch.tensor([len(all_results)]), + 'metadata/k_neighbors': torch.tensor([self.computer.k]), + 'metadata/d_cdc': torch.tensor([self.computer.d_cdc]), + 'metadata/gamma': torch.tensor([self.computer.gamma]), + } + + # Add shape information for each sample + for sample in self.batcher.samples: + idx = sample.global_idx + tensors_dict[f'shapes/{idx}'] = torch.tensor(sample.shape) + + # Add CDC results (convert numpy to torch tensors) + for global_idx, (eigvecs, eigvals) in all_results.items(): + # Convert numpy arrays to torch tensors + if isinstance(eigvecs, np.ndarray): + eigvecs = torch.from_numpy(eigvecs) + if isinstance(eigvals, np.ndarray): + eigvals = torch.from_numpy(eigvals) + + tensors_dict[f'eigenvectors/{global_idx}'] = eigvecs + tensors_dict[f'eigenvalues/{global_idx}'] = eigvals + + save_file(tensors_dict, save_path) + + file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024 + print(f"\nSaved to {save_path}") + print(f"File size: {file_size_gb:.2f} GB") + + return save_path + + +class GammaBDataset: + """ + Efficient loader for Γ_b matrices during training + Handles variable-size latents + """ + + def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'): + self.device = torch.device(device if torch.cuda.is_available() else 
'cpu') + self.gamma_b_path = Path(gamma_b_path) + + # Load metadata + print(f"Loading Γ_b from {gamma_b_path}...") + from safetensors import safe_open + + with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f: + self.num_samples = int(f.get_tensor('metadata/num_samples').item()) + self.d_cdc = int(f.get_tensor('metadata/d_cdc').item()) + + print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})") + + @torch.no_grad() + def get_gamma_b_sqrt( + self, + indices: Union[List[int], np.ndarray, torch.Tensor], + device: Optional[str] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get Γ_b^(1/2) components for a batch of indices + + Args: + indices: Sample indices + device: Device to load to (defaults to self.device) + + Returns: + eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample! + eigenvalues: (B, d_cdc) + """ + if device is None: + device = self.device + + # Convert indices to list + if isinstance(indices, torch.Tensor): + indices = indices.cpu().numpy().tolist() + elif isinstance(indices, np.ndarray): + indices = indices.tolist() + + # Load from safetensors + from safetensors import safe_open + + eigenvectors_list = [] + eigenvalues_list = [] + + with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f: + for idx in indices: + idx = int(idx) + eigvecs = f.get_tensor(f'eigenvectors/{idx}').float() + eigvals = f.get_tensor(f'eigenvalues/{idx}').float() + + eigenvectors_list.append(eigvecs) + eigenvalues_list.append(eigvals) + + # Stack - all should have same d_cdc and d within a batch (enforced by bucketing) + # Check if all eigenvectors have the same dimension + dims = [ev.shape[1] for ev in eigenvectors_list] + if len(set(dims)) > 1: + # Dimension mismatch! This shouldn't happen with proper bucketing + # but can occur if batch contains mixed sizes + raise RuntimeError( + f"CDC eigenvector dimension mismatch in batch: {set(dims)}. " + f"Batch indices: {indices}. 
" + f"This means the training batch contains images of different sizes, " + f"which violates CDC's requirement for uniform latent dimensions per batch. " + f"Check that your dataloader buckets are configured correctly." + ) + + eigenvectors = torch.stack(eigenvectors_list, dim=0) + eigenvalues = torch.stack(eigenvalues_list, dim=0) + + return eigenvectors, eigenvalues + + def get_shape(self, idx: int) -> Tuple[int, ...]: + """Get the original shape for a sample""" + from safetensors import safe_open + + with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f: + shape_tensor = f.get_tensor(f'shapes/{idx}') + return tuple(shape_tensor.numpy().tolist()) + + @torch.no_grad() + def compute_sigma_t_x( + self, + eigenvectors: torch.Tensor, + eigenvalues: torch.Tensor, + x: torch.Tensor, + t: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2) + + Args: + eigenvectors: (B, d_cdc, d) + eigenvalues: (B, d_cdc) + x: (B, d) or (B, C, H, W) - will be flattened if needed + t: (B,) or scalar time + + Returns: + result: Same shape as input x + """ + # Store original shape to restore later + orig_shape = x.shape + + # Flatten x if it's 4D + if x.dim() == 4: + B, C, H, W = x.shape + x = x.reshape(B, -1) # (B, C*H*W) + + if not isinstance(t, torch.Tensor): + t = torch.tensor(t, device=x.device, dtype=x.dtype) + + if t.dim() == 0: + t = t.expand(x.shape[0]) + + t = t.view(-1, 1) + + # Early return for t=0 to avoid numerical errors + if torch.allclose(t, torch.zeros_like(t), atol=1e-8): + return x.reshape(orig_shape) + + # Check if CDC is disabled (all eigenvalues are zero) + # This happens for buckets with < k_neighbors samples + if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8): + # Fallback to standard Gaussian noise (no CDC correction) + return x.reshape(orig_shape) + + # Γ_b^(1/2) @ x using low-rank representation + Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x) + sqrt_eigenvalues = 
torch.sqrt(eigenvalues.clamp(min=1e-10)) + sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x + gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x) + + # Σ_t @ x + result = (1 - t) * x + t * gamma_sqrt_x + + # Restore original shape + result = result.reshape(orig_shape) + + return result diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 06fe0b95..b40a1654 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -2,10 +2,8 @@ import argparse import math import os import numpy as np -import toml -import json import time -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple import torch from accelerate import Accelerator, PartialState @@ -183,7 +181,7 @@ def sample_image_inference( if cfg_scale != 1.0: logger.info(f"negative_prompt: {negative_prompt}") elif negative_prompt != "": - logger.info(f"negative prompt is ignored because scale is 1.0") + logger.info("negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") @@ -469,8 +467,16 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype, + gamma_b_dataset=None, batch_indices=None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Get noisy model input and timesteps for training. 
+ + Args: + gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise + batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided) + """ bsz, _, h, w = latents.shape assert bsz > 0, "Batch size not large enough" num_timesteps = noise_scheduler.config.num_train_timesteps @@ -514,6 +520,44 @@ def get_noisy_model_input_and_timesteps( # Broadcast sigmas to latent shape sigmas = sigmas.view(-1, 1, 1, 1) + # Apply CDC-FM geometry-aware noise transformation if enabled + if gamma_b_dataset is not None and batch_indices is not None: + # Normalize timesteps to [0, 1] for CDC-FM + t_normalized = timesteps / num_timesteps + + # Process each sample individually to handle potential dimension mismatches + # (can happen with multi-subset training where bucketing differs between preprocessing and training) + B, C, H, W = noise.shape + noise_transformed = [] + + for i in range(B): + idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i] + + # Get cached shape for this sample + cached_shape = gamma_b_dataset.get_shape(idx) + current_shape = (C, H, W) + + if cached_shape != current_shape: + # Shape mismatch - sample was bucketed differently between preprocessing and training + # Use standard Gaussian noise for this sample (no CDC) + logger.warning( + f"CDC shape mismatch for sample {idx}: " + f"cached {cached_shape} vs current {current_shape}. " + f"Using Gaussian noise (no CDC)." 
+ ) + noise_transformed.append(noise[i]) + else: + # Shapes match - apply CDC transformation + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device) + + noise_flat = noise[i].reshape(1, -1) + t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized + + noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single) + noise_transformed.append(noise_cdc_flat.reshape(C, H, W)) + + noise = torch.stack(noise_transformed, dim=0) + # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: diff --git a/library/train_util.py b/library/train_util.py index 756d88b1..bb47a846 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1569,11 +1569,19 @@ class BaseDataset(torch.utils.data.Dataset): flippeds = [] # 変数名が微妙 text_encoder_outputs_list = [] custom_attributes = [] + indices = [] # CDC-FM: track global dataset indices for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] + # CDC-FM: Get global index for this image + # Create a sorted list of keys to ensure deterministic indexing + if not hasattr(self, '_image_key_to_index'): + self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))} + global_idx = self._image_key_to_index[image_key] + indices.append(global_idx) + custom_attributes.append(subset.custom_attributes) # in case of fine tuning, is_reg is always False @@ -1819,6 +1827,9 @@ class BaseDataset(torch.utils.data.Dataset): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + # CDC-FM: Add global indices to batch + example["indices"] = torch.LongTensor(indices) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -2690,6 +2701,127 @@ class 
DatasetGroup(torch.utils.data.ConcatDataset): dataset.new_cache_text_encoder_outputs(models, accelerator) accelerator.wait_for_everyone() + def cache_cdc_gamma_b( + self, + cdc_output_path: str, + k_neighbors: int = 256, + k_bandwidth: int = 8, + d_cdc: int = 8, + gamma: float = 1.0, + force_recache: bool = False, + accelerator: Optional["Accelerator"] = None, + ) -> str: + """ + Cache CDC Γ_b matrices for all latents in the dataset + + Args: + cdc_output_path: Path to save cdc_gamma_b.safetensors + k_neighbors: k-NN neighbors + k_bandwidth: Bandwidth estimation neighbors + d_cdc: CDC subspace dimension + gamma: CDC strength + force_recache: Force recompute even if cache exists + accelerator: For multi-GPU support + + Returns: + Path to cached CDC file + """ + from pathlib import Path + + cdc_path = Path(cdc_output_path) + + # Check if valid cache exists + if cdc_path.exists() and not force_recache: + if self._is_cdc_cache_valid(cdc_path, k_neighbors, d_cdc, gamma): + logger.info(f"Valid CDC cache found at {cdc_path}, skipping preprocessing") + return str(cdc_path) + else: + logger.info(f"CDC cache found but invalid, will recompute") + + # Only main process computes CDC + is_main = accelerator is None or accelerator.is_main_process + if not is_main: + if accelerator is not None: + accelerator.wait_for_everyone() + return str(cdc_path) + + logger.info("=" * 60) + logger.info("Starting CDC-FM preprocessing") + logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}") + logger.info("=" * 60) + + # Initialize CDC preprocessor + from library.cdc_fm import CDCPreprocessor + + preprocessor = CDCPreprocessor( + k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu" + ) + + # Get caching strategy for loading latents + from library.strategy_base import LatentsCachingStrategy + + caching_strategy = LatentsCachingStrategy.get_strategy() + + # Collect all latents from 
all datasets + for dataset_idx, dataset in enumerate(self.datasets): + logger.info(f"Loading latents from dataset {dataset_idx}...") + image_infos = list(dataset.image_data.values()) + + for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset {dataset_idx}")): + # Load latent from disk or memory + if info.latents is not None: + latent = info.latents + elif info.latents_npz is not None: + # Load from disk + latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso) + if latent is None: + logger.warning(f"Failed to load latent from {info.latents_npz}, skipping") + continue + else: + logger.warning(f"No latent found for {info.absolute_path}, skipping") + continue + + # Add to preprocessor (with unique global index across all datasets) + actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx + preprocessor.add_latent(latent=latent, global_idx=actual_global_idx, shape=latent.shape, metadata={"image_key": info.image_key}) + + # Compute and save + logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...") + preprocessor.compute_all(save_path=cdc_path) + + if accelerator is not None: + accelerator.wait_for_everyone() + + return str(cdc_path) + + def _is_cdc_cache_valid(self, cdc_path: "pathlib.Path", k_neighbors: int, d_cdc: int, gamma: float) -> bool: + """Check if CDC cache has matching hyperparameters""" + try: + from safetensors import safe_open + + with safe_open(str(cdc_path), framework="pt", device="cpu") as f: + cached_k = int(f.get_tensor("metadata/k_neighbors").item()) + cached_d = int(f.get_tensor("metadata/d_cdc").item()) + cached_gamma = float(f.get_tensor("metadata/gamma").item()) + cached_num = int(f.get_tensor("metadata/num_samples").item()) + + expected_num = sum(len(d.image_data) for d in self.datasets) + + valid = cached_k == k_neighbors and cached_d == d_cdc and abs(cached_gamma - gamma) < 1e-6 and cached_num == expected_num + + if not 
valid: + logger.info( + f"Cache mismatch: k={cached_k} (expected {k_neighbors}), " + f"d_cdc={cached_d} (expected {d_cdc}), " + f"gamma={cached_gamma} (expected {gamma}), " + f"num={cached_num} (expected {expected_num})" + ) + + return valid + except Exception as e: + logger.warning(f"Error validating CDC cache: {e}") + return False + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) diff --git a/tests/library/test_cdc_eigenvalue_scaling.py b/tests/library/test_cdc_eigenvalue_scaling.py new file mode 100644 index 00000000..65dcadd9 --- /dev/null +++ b/tests/library/test_cdc_eigenvalue_scaling.py @@ -0,0 +1,242 @@ +""" +Tests to verify CDC eigenvalue scaling is correct. + +These tests ensure eigenvalues are properly scaled to prevent training loss explosion. +""" + +import numpy as np +import pytest +import torch +from safetensors import safe_open + +from library.cdc_fm import CDCPreprocessor + + +class TestEigenvalueScaling: + """Test that eigenvalues are properly scaled to reasonable ranges""" + + def test_eigenvalues_in_correct_range(self, tmp_path): + """Verify eigenvalues are scaled to ~0.01-1.0 range, not millions""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" + ) + + # Add deterministic latents with structured patterns + for i in range(10): + # Create gradient pattern: values from 0 to 2.0 across spatial dims + latent = torch.zeros(16, 8, 8, dtype=torch.float32) + for h in range(8): + for w in range(8): + latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0] + # Add per-sample variation + latent = latent + i * 0.1 + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + output_path = tmp_path / "test_gamma_b.safetensors" + result_path = preprocessor.compute_all(save_path=output_path) + + # Verify eigenvalues are in correct range + with safe_open(str(result_path), framework="pt", device="cpu") as f: + all_eigvals = [] + for 
i in range(10): + eigvals = f.get_tensor(f"eigenvalues/{i}").numpy() + all_eigvals.extend(eigvals) + + all_eigvals = np.array(all_eigvals) + + # Filter out zero eigenvalues (from padding when k < d_cdc) + non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] + + # Critical assertions for eigenvalue scale + assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)" + assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues" + assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large" + + # Check sqrt (used in noise) is reasonable + sqrt_max = np.sqrt(all_eigvals.max()) + assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion" + + print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]") + print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}") + print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}") + print(f"✓ sqrt(max): {sqrt_max:.4f}") + + def test_eigenvalues_not_all_zero(self, tmp_path): + """Ensure eigenvalues are not all zero (indicating computation failure)""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + for i in range(10): + # Create deterministic pattern + latent = torch.zeros(16, 4, 4, dtype=torch.float32) + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2 + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + output_path = tmp_path / "test_gamma_b.safetensors" + result_path = preprocessor.compute_all(save_path=output_path) + + with safe_open(str(result_path), framework="pt", device="cpu") as f: + all_eigvals = [] + for i in range(10): + eigvals = f.get_tensor(f"eigenvalues/{i}").numpy() + all_eigvals.extend(eigvals) + + all_eigvals = np.array(all_eigvals) + non_zero_eigvals = all_eigvals[all_eigvals > 1e-6] + + # With 
clamping, eigenvalues will be in range [1e-3, gamma*1.0] + # Check that we have some non-zero eigenvalues + assert len(non_zero_eigvals) > 0, "All eigenvalues are zero - computation failed" + + # Check they're in the expected clamped range + assert np.all(non_zero_eigvals >= 1e-3), f"Some eigenvalues below clamp min: {np.min(non_zero_eigvals)}" + assert np.all(non_zero_eigvals <= 1.0), f"Some eigenvalues above clamp max: {np.max(non_zero_eigvals)}" + + print(f"\n✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}") + print(f"✓ Range: [{np.min(non_zero_eigvals):.4f}, {np.max(non_zero_eigvals):.4f}]") + print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}") + + def test_fp16_storage_no_overflow(self, tmp_path): + """Verify fp16 storage doesn't overflow (max fp16 = 65,504)""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu" + ) + + for i in range(10): + # Create deterministic pattern with higher magnitude + latent = torch.zeros(16, 8, 8, dtype=torch.float32) + for h in range(8): + for w in range(8): + latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0] + latent = latent + i * 0.3 + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + output_path = tmp_path / "test_gamma_b.safetensors" + result_path = preprocessor.compute_all(save_path=output_path) + + with safe_open(str(result_path), framework="pt", device="cpu") as f: + # Check dtype is fp16 + eigvecs = f.get_tensor("eigenvectors/0") + eigvals = f.get_tensor("eigenvalues/0") + + assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}" + assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}" + + # Check no values near fp16 max (would indicate overflow) + FP16_MAX = 65504 + max_eigval = eigvals.max().item() + + assert max_eigval < 100, ( + f"Eigenvalue {max_eigval:.2e} is suspiciously large for fp16 storage. 
" + f"May indicate overflow (fp16 max = {FP16_MAX})" + ) + + print(f"\n✓ Storage dtype: {eigvals.dtype}") + print(f"✓ Max eigenvalue: {max_eigval:.4f} (safe for fp16)") + + def test_latent_magnitude_preserved(self, tmp_path): + """Verify latent magnitude is preserved (no unwanted normalization)""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + # Store original latents with deterministic patterns + original_latents = [] + for i in range(10): + # Create structured pattern with known magnitude + latent = torch.zeros(16, 4, 4, dtype=torch.float32) + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5 + original_latents.append(latent.clone()) + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + # Compute original latent statistics + orig_std = torch.stack(original_latents).std().item() + + output_path = tmp_path / "test_gamma_b.safetensors" + preprocessor.compute_all(save_path=output_path) + + # The stored latents should preserve original magnitude + stored_latents_std = np.std([s.latent for s in preprocessor.batcher.samples]) + + # Should be similar to original (within 20% due to potential batching effects) + assert 0.8 * orig_std < stored_latents_std < 1.2 * orig_std, ( + f"Stored latent std {stored_latents_std:.2f} differs too much from " + f"original {orig_std:.2f}. Latent magnitude was not preserved." 
+ ) + + print(f"\n✓ Original latent std: {orig_std:.2f}") + print(f"✓ Stored latent std: {stored_latents_std:.2f}") + + +class TestTrainingLossScale: + """Test that eigenvalues produce reasonable loss magnitudes""" + + def test_noise_magnitude_reasonable(self, tmp_path): + """Verify CDC noise has reasonable magnitude for training""" + from library.cdc_fm import GammaBDataset + + # Create CDC cache with deterministic data + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + for i in range(10): + # Create deterministic pattern + latent = torch.zeros(16, 4, 4, dtype=torch.float32) + for c in range(16): + for h in range(4): + for w in range(4): + latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1 + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + output_path = tmp_path / "test_gamma_b.safetensors" + cdc_path = preprocessor.compute_all(save_path=output_path) + + # Load and compute noise + gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu") + + # Simulate training scenario with deterministic data + batch_size = 3 + latents = torch.zeros(batch_size, 16, 4, 4) + for b in range(batch_size): + for c in range(16): + for h in range(4): + for w in range(4): + latents[b, c, h, w] = (b + c + h + w) / 24.0 + t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps + indices = [0, 5, 9] + + eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices) + noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t) + + # Check noise magnitude + noise_std = noise.std().item() + latent_std = latents.std().item() + + # Noise should be similar magnitude to input latents (within 10x) + ratio = noise_std / latent_std + assert 0.1 < ratio < 10.0, ( + f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) " + f"ratio {ratio:.2f} is too extreme. Will cause training instability." 
+ ) + + # Simulated MSE loss should be reasonable + simulated_loss = torch.mean((noise - latents) ** 2).item() + assert simulated_loss < 100.0, ( + f"Simulated MSE loss {simulated_loss:.2f} is too high. " + f"Should be O(0.1-1.0) for stable training." + ) + + print(f"\n✓ Noise/latent ratio: {ratio:.2f}") + print(f"✓ Simulated MSE loss: {simulated_loss:.4f}") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_interpolation_comparison.py b/tests/library/test_cdc_interpolation_comparison.py new file mode 100644 index 00000000..9ad71eaf --- /dev/null +++ b/tests/library/test_cdc_interpolation_comparison.py @@ -0,0 +1,164 @@ +""" +Test comparing interpolation vs pad/truncate for CDC preprocessing. + +This test quantifies the difference between the two approaches. +""" + +import numpy as np +import pytest +import torch +import torch.nn.functional as F + + +class TestInterpolationComparison: + """Compare interpolation vs pad/truncate""" + + def test_intermediate_representation_quality(self): + """Compare intermediate representation quality for CDC computation""" + # Create test latents with different sizes - deterministic + latent_small = torch.zeros(16, 4, 4) + for c in range(16): + for h in range(4): + for w in range(4): + latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0 + + latent_large = torch.zeros(16, 8, 8) + for c in range(16): + for h in range(8): + for w in range(8): + latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0 + + target_h, target_w = 6, 6 # Median size + + # Method 1: Interpolation + def interpolate_method(latent, target_h, target_w): + latent_input = latent.unsqueeze(0) # (1, C, H, W) + latent_resized = F.interpolate( + latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False + ) + # Resize back + C, H, W = latent.shape + latent_reconstructed = F.interpolate( + latent_resized, size=(H, W), mode='bilinear', align_corners=False + ) + error = 
torch.mean(torch.abs(latent_reconstructed - latent_input)).item() + relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8) + return relative_error + + # Method 2: Pad/Truncate + def pad_truncate_method(latent, target_h, target_w): + C, H, W = latent.shape + latent_flat = latent.reshape(-1) + target_dim = C * target_h * target_w + current_dim = C * H * W + + if current_dim == target_dim: + latent_resized_flat = latent_flat + elif current_dim > target_dim: + # Truncate + latent_resized_flat = latent_flat[:target_dim] + else: + # Pad + latent_resized_flat = torch.zeros(target_dim) + latent_resized_flat[:current_dim] = latent_flat + + # Resize back + if current_dim == target_dim: + latent_reconstructed_flat = latent_resized_flat + elif current_dim > target_dim: + # Pad back + latent_reconstructed_flat = torch.zeros(current_dim) + latent_reconstructed_flat[:target_dim] = latent_resized_flat + else: + # Truncate back + latent_reconstructed_flat = latent_resized_flat[:current_dim] + + latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W) + error = torch.mean(torch.abs(latent_reconstructed - latent)).item() + relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8) + return relative_error + + # Compare for small latent (needs padding) + interp_error_small = interpolate_method(latent_small, target_h, target_w) + pad_error_small = pad_truncate_method(latent_small, target_h, target_w) + + # Compare for large latent (needs truncation) + interp_error_large = interpolate_method(latent_large, target_h, target_w) + truncate_error_large = pad_truncate_method(latent_large, target_h, target_w) + + print("\n" + "=" * 60) + print("Reconstruction Error Comparison") + print("=" * 60) + print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):") + print(f" Interpolation error: {interp_error_small:.6f}") + print(f" Pad/truncate error: {pad_error_small:.6f}") + if pad_error_small > 0: + print(f" Improvement: {(pad_error_small - interp_error_small) 
/ pad_error_small * 100:.2f}%") + else: + print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)") + print(f" BUT the intermediate representation is corrupted with zeros!") + + print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):") + print(f" Interpolation error: {interp_error_large:.6f}") + print(f" Pad/truncate error: {truncate_error_large:.6f}") + if truncate_error_large > 0: + print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%") + + # The key insight: Reconstruction error is NOT what matters for CDC! + # What matters is the INTERMEDIATE representation quality used for geometry estimation. + # Pad/truncate may have good reconstruction, but the intermediate is corrupted. + + print("\nKey insight: For CDC, intermediate representation quality matters,") + print("not reconstruction error. Interpolation preserves spatial structure.") + + # Verify interpolation errors are reasonable + assert interp_error_small < 1.0, "Interpolation should have reasonable error" + assert interp_error_large < 1.0, "Interpolation should have reasonable error" + + def test_spatial_structure_preservation(self): + """Test that interpolation preserves spatial structure better than pad/truncate""" + # Create a latent with clear spatial pattern (gradient) + C, H, W = 16, 4, 4 + latent = torch.zeros(C, H, W) + for i in range(H): + for j in range(W): + latent[:, i, j] = i * W + j # Gradient pattern + + target_h, target_w = 6, 6 + + # Interpolation + latent_input = latent.unsqueeze(0) + latent_interp = F.interpolate( + latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False + ).squeeze(0) + + # Pad/truncate + latent_flat = latent.reshape(-1) + target_dim = C * target_h * target_w + latent_padded = torch.zeros(target_dim) + latent_padded[:len(latent_flat)] = latent_flat + latent_pad = latent_padded.reshape(C, target_h, target_w) + + # Check gradient preservation + # For interpolation, adjacent pixels should 
have smooth gradients + grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean() + grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean() + + # For padding, there will be abrupt changes (gradient to zero) + grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean() + grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean() + + print("\n" + "=" * 60) + print("Spatial Structure Preservation") + print("=" * 60) + print(f"\nGradient smoothness (lower is smoother):") + print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}") + print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}") + + # Padding introduces larger gradients due to abrupt zeros + assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients" + assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/library/test_cdc_standalone.py b/tests/library/test_cdc_standalone.py new file mode 100644 index 00000000..f945a184 --- /dev/null +++ b/tests/library/test_cdc_standalone.py @@ -0,0 +1,232 @@ +""" +Standalone tests for CDC-FM integration. + +These tests focus on CDC-FM specific functionality without importing +the full training infrastructure that has problematic dependencies. 
+""" + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch +from safetensors.torch import save_file + +from library.cdc_fm import CDCPreprocessor, GammaBDataset + + +class TestCDCPreprocessor: + """Test CDC preprocessing functionality""" + + def test_cdc_preprocessor_basic_workflow(self, tmp_path): + """Test basic CDC preprocessing with small dataset""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + # Add 10 small latents + for i in range(10): + latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + # Compute and save + output_path = tmp_path / "test_gamma_b.safetensors" + result_path = preprocessor.compute_all(save_path=output_path) + + # Verify file was created + assert Path(result_path).exists() + + # Verify structure + from safetensors import safe_open + + with safe_open(str(result_path), framework="pt", device="cpu") as f: + assert f.get_tensor("metadata/num_samples").item() == 10 + assert f.get_tensor("metadata/k_neighbors").item() == 5 + assert f.get_tensor("metadata/d_cdc").item() == 4 + + # Check first sample + eigvecs = f.get_tensor("eigenvectors/0") + eigvals = f.get_tensor("eigenvalues/0") + + assert eigvecs.shape[0] == 4 # d_cdc + assert eigvals.shape[0] == 4 # d_cdc + + def test_cdc_preprocessor_different_shapes(self, tmp_path): + """Test CDC preprocessing with variable-size latents (bucketing)""" + preprocessor = CDCPreprocessor( + k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu" + ) + + # Add 5 latents of shape (16, 4, 4) + for i in range(5): + latent = torch.randn(16, 4, 4, dtype=torch.float32) + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + # Add 5 latents of different shape (16, 8, 8) + for i in range(5, 10): + latent = torch.randn(16, 8, 8, dtype=torch.float32) + preprocessor.add_latent(latent=latent, global_idx=i, 
shape=latent.shape) + + # Compute and save + output_path = tmp_path / "test_gamma_b_multi.safetensors" + result_path = preprocessor.compute_all(save_path=output_path) + + # Verify both shape groups were processed + from safetensors import safe_open + + with safe_open(str(result_path), framework="pt", device="cpu") as f: + # Check shapes are stored + shape_0 = f.get_tensor("shapes/0") + shape_5 = f.get_tensor("shapes/5") + + assert tuple(shape_0.tolist()) == (16, 4, 4) + assert tuple(shape_5.tolist()) == (16, 8, 8) + + +class TestGammaBDataset: + """Test GammaBDataset loading and retrieval""" + + @pytest.fixture + def sample_cdc_cache(self, tmp_path): + """Create a sample CDC cache file for testing""" + cache_path = tmp_path / "test_gamma_b.safetensors" + + # Create mock Γ_b data for 5 samples + tensors = { + "metadata/num_samples": torch.tensor([5]), + "metadata/k_neighbors": torch.tensor([10]), + "metadata/d_cdc": torch.tensor([4]), + "metadata/gamma": torch.tensor([1.0]), + } + + # Add shape and CDC data for each sample + for i in range(5): + tensors[f"shapes/{i}"] = torch.tensor([16, 8, 8]) # C, H, W + tensors[f"eigenvectors/{i}"] = torch.randn(4, 1024, dtype=torch.float32) # d_cdc x d + tensors[f"eigenvalues/{i}"] = torch.rand(4, dtype=torch.float32) + 0.1 # positive + + save_file(tensors, str(cache_path)) + return cache_path + + def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache): + """Test that GammaBDataset loads metadata correctly""" + gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + + assert gamma_b_dataset.num_samples == 5 + assert gamma_b_dataset.d_cdc == 4 + + def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache): + """Test retrieving Γ_b^(1/2) components""" + gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + + # Get Γ_b for indices [0, 2, 4] + indices = [0, 2, 4] + eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(indices, device="cpu") + + # Check shapes + 
assert eigenvectors.shape == (3, 4, 1024) # (batch, d_cdc, d) + assert eigenvalues.shape == (3, 4) # (batch, d_cdc) + + # Check values are positive + assert torch.all(eigenvalues > 0) + + def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache): + """Test compute_sigma_t_x returns x unchanged at t=0""" + gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + + # Create test latents (batch of 3, matching d=1024 flattened) + x = torch.randn(3, 1024) # B, d (flattened) + t = torch.zeros(3) # t = 0 for all samples + + # Get Γ_b components + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 1, 2], device="cpu") + + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) + + # At t=0, should return x unchanged + assert torch.allclose(sigma_t_x, x, atol=1e-6) + + def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache): + """Test compute_sigma_t_x returns correct shape""" + gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + + x = torch.randn(2, 1024) # B, d (flattened) + t = torch.tensor([0.3, 0.7]) + + # Get Γ_b components + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([1, 3], device="cpu") + + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) + + # Should return same shape as input + assert sigma_t_x.shape == x.shape + + def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache): + """Test compute_sigma_t_x produces finite values""" + gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu") + + x = torch.randn(3, 1024) # B, d (flattened) + t = torch.rand(3) # Random timesteps in [0, 1] + + # Get Γ_b components + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 2, 4], device="cpu") + + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) + + # Should not contain NaNs or Infs + assert not torch.isnan(sigma_t_x).any() + assert torch.isfinite(sigma_t_x).all() + + +class TestCDCEndToEnd: 
+ """End-to-end CDC workflow tests""" + + def test_full_preprocessing_and_usage_workflow(self, tmp_path): + """Test complete workflow: preprocess -> save -> load -> use""" + # Step 1: Preprocess latents + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu" + ) + + num_samples = 10 + for i in range(num_samples): + latent = torch.randn(16, 4, 4, dtype=torch.float32) + preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape) + + output_path = tmp_path / "cdc_gamma_b.safetensors" + cdc_path = preprocessor.compute_all(save_path=output_path) + + # Step 2: Load with GammaBDataset + gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu") + + assert gamma_b_dataset.num_samples == num_samples + + # Step 3: Use in mock training scenario + batch_size = 3 + batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256) + batch_t = torch.rand(batch_size) + batch_indices = [0, 5, 9] + + # Get Γ_b components + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu") + + # Compute geometry-aware noise + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) + + # Verify output is reasonable + assert sigma_t_x.shape == batch_latents_flat.shape + assert not torch.isnan(sigma_t_x).any() + assert torch.isfinite(sigma_t_x).all() + + # Verify that noise changes with different timesteps + sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size)) + sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size)) + + # At t=0, should be close to x; at t=1, should be different + assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6) + assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/train_network.py b/train_network.py index 6cebf5fc..be0e1601 100644 
--- a/train_network.py +++ b/train_network.py @@ -622,6 +622,23 @@ class NetworkTrainer: accelerator.wait_for_everyone() + # CDC-FM preprocessing + if hasattr(args, "use_cdc_fm") and args.use_cdc_fm: + logger.info("CDC-FM enabled, preprocessing Γ_b matrices...") + cdc_output_path = os.path.join(args.output_dir, "cdc_gamma_b.safetensors") + + self.cdc_cache_path = train_dataset_group.cache_cdc_gamma_b( + cdc_output_path=cdc_output_path, + k_neighbors=args.cdc_k_neighbors, + k_bandwidth=args.cdc_k_bandwidth, + d_cdc=args.cdc_d_cdc, + gamma=args.cdc_gamma, + force_recache=args.force_recache_cdc, + accelerator=accelerator, + ) + else: + self.cdc_cache_path = None + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu text_encoding_strategy = self.get_text_encoding_strategy(args) @@ -634,7 +651,7 @@ class NetworkTrainer: if val_dataset_group is not None: self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) - if unet is None: + if unet is None: # lazy load unet if needed. 
text encoders may be freed or replaced with dummy models for saving memory unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders) @@ -643,10 +660,10 @@ class NetworkTrainer: accelerator.print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) - if args.base_weights is not None: + if args.base_weights is not None: # base_weights が指定されている場合は、指定された重みを読み込みマージする for i, weight_path in enumerate(args.base_weights): - if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: multiplier = 1.0 else: multiplier = args.base_weights_multiplier[i] @@ -660,6 +677,17 @@ class NetworkTrainer: accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + # Load CDC-FM Γ_b dataset if enabled + if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None: + from library.cdc_fm import GammaBDataset + + logger.info(f"Loading CDC Γ_b dataset from {self.cdc_cache_path}") + self.gamma_b_dataset = GammaBDataset( + gamma_b_path=self.cdc_cache_path, device="cuda" if torch.cuda.is_available() else "cpu" + ) + else: + self.gamma_b_dataset = None + # prepare network net_kwargs = {} if args.network_args is not None: