This commit is contained in:
Dave Lage
2025-10-18 18:36:23 +00:00
committed by GitHub
11 changed files with 2198 additions and 11 deletions

View File

@@ -43,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 faiss-cpu==1.12.0
pip install -r requirements.txt
- name: Test with pytest

1
.gitignore vendored
View File

@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
benchmark_*.py

View File

@@ -1,7 +1,5 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union
import torch
@@ -36,6 +34,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
self.model_type: Optional[str] = None
self.gamma_b_dataset = None # CDC-FM Γ_b dataset
def assert_extra_args(
self,
@@ -327,9 +326,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
# Get CDC parameters if enabled
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None
latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None
# Get noisy model input and timesteps
# If CDC is enabled, this will transform the noise with geometry-aware covariance
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
)
# pack latents and get img_ids
@@ -456,6 +461,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
# CDC-FM metadata
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
@@ -494,7 +508,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
logger.info("T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
@@ -533,6 +547,72 @@ def setup_parser() -> argparse.ArgumentParser:
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
# CDC-FM arguments
parser.add_argument(
"--use_cdc_fm",
action="store_true",
help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
" / CDC-FMCarré du Champ Flow Matchingを有効にして幾何学的イズを使用",
)
parser.add_argument(
"--cdc_k_neighbors",
type=int,
default=256,
help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
" / CDC-FMのk-NNグラフの近傍数デフォルト: 256",
)
parser.add_argument(
"--cdc_k_bandwidth",
type=int,
default=8,
help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
" / CDC-FMの帯域幅推定の近傍数デフォルト: 8",
)
parser.add_argument(
"--cdc_d_cdc",
type=int,
default=8,
help="Dimension of CDC subspace (default: 8)"
" / CDCサブ空間の次元デフォルト: 8",
)
parser.add_argument(
"--cdc_gamma",
type=float,
default=1.0,
help="CDC strength parameter (default: 1.0)"
" / CDC強度パラメータデフォルト: 1.0",
)
parser.add_argument(
"--force_recache_cdc",
action="store_true",
help="Force recompute CDC cache even if valid cache exists"
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
)
parser.add_argument(
"--cdc_debug",
action="store_true",
help="Enable verbose CDC debug output showing bucket details"
" / CDCの詳細デバッグ出力を有効化バケット詳細表示",
)
parser.add_argument(
"--cdc_adaptive_k",
action="store_true",
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
)
parser.add_argument(
"--cdc_min_bucket_size",
type=int,
default=16,
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスイズを使用。"
"--cdc_adaptive_k有効時のみ関連デフォルト: 16",
)
return parser

867
library/cdc_fm.py Normal file
View File

@@ -0,0 +1,867 @@
import logging
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from safetensors.torch import save_file
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class LatentSample:
    """
    Container for a single latent with metadata.

    Holds one flattened latent vector plus the bookkeeping needed to write
    its CDC result back next to the latent cache file it came from.
    """
    latent: np.ndarray  # (d,) flattened latent vector
    global_idx: int  # Global index in dataset
    shape: Tuple[int, ...]  # Original shape before flattening (C, H, W)
    latents_npz_path: str  # Path to the latent cache file
    metadata: Optional[Dict] = None  # Any extra info (prompt, filename, etc.)
class CarreDuChampComputer:
    """
    Core CDC-FM computation - agnostic to data source
    Just handles the math for a batch of same-size latents

    Given N flattened latents of identical dimension d, this class builds a
    k-NN graph, estimates a per-point Gaussian bandwidth, and extracts the
    top d_cdc eigenpairs of each point's locally weighted covariance
    (the "Gamma_b" operator of CDC-FM).
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda'
    ):
        # k: neighbors in the k-NN graph; k_bw: neighbor rank used for the
        # bandwidth estimate; d_cdc: rank of the stored eigen-decomposition;
        # gamma: global CDC strength applied during eigenvalue rescaling.
        self.k = k_neighbors
        self.k_bw = k_bandwidth
        self.d_cdc = d_cdc
        self.gamma = gamma
        # Silently fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build k-NN graph using pure PyTorch
        Args:
            latents_np: (N, d) numpy array of same-dimensional latents
        Returns:
            distances: (N, k_actual+1) distances (k_actual may be less than k if N is small)
            indices: (N, k_actual+1) neighbor indices

        Column 0 of both outputs is each point's self-match (distance 0),
        which callers skip when reading neighbors.
        """
        N, d = latents_np.shape
        # Clamp k to available neighbors (can't have more neighbors than samples)
        k_actual = min(self.k, N - 1)
        # Convert to torch tensor
        latents_tensor = torch.from_numpy(latents_np).to(self.device)
        # Compute pairwise L2 distances efficiently
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
        # This is more memory efficient than computing all pairwise differences
        # For large batches, we'll chunk the computation
        # NOTE(review): torch.cdist(p=2) already returns L2 distances, so the
        # square-then-sqrt round trip below is redundant (though harmless).
        chunk_size = 1000  # Process 1000 queries at a time to manage memory
        if N <= chunk_size:
            # Small batch: compute all at once
            distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2
            distances_k_sq, indices_k = torch.topk(
                distances_sq, k=k_actual + 1, dim=1, largest=False
            )
            distances = torch.sqrt(distances_k_sq).cpu().numpy()
            indices = indices_k.cpu().numpy()
        else:
            # Large batch: chunk to avoid OOM
            distances_list = []
            indices_list = []
            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                chunk = latents_tensor[i:end_i]
                # Compute distances for this chunk (queries: chunk, keys: all N)
                distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2
                distances_k_sq, indices_k = torch.topk(
                    distances_sq, k=k_actual + 1, dim=1, largest=False
                )
                distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy())
                indices_list.append(indices_k.cpu().numpy())
                # Free memory
                del distances_sq, distances_k_sq, indices_k
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            distances = np.concatenate(distances_list, axis=0)
            indices = np.concatenate(indices_list, axis=0)
        return distances, indices

    @torch.no_grad()
    def compute_gamma_b_single(
        self,
        point_idx: int,
        latents_np: np.ndarray,
        distances: np.ndarray,
        indices: np.ndarray,
        epsilon: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Γ_b for a single point
        Args:
            point_idx: Index of point to process
            latents_np: (N, d) all latents in this batch
            distances: (N, k+1) precomputed distances
            indices: (N, k+1) precomputed neighbor indices
            epsilon: (N,) bandwidth per point
        Returns:
            eigenvectors: (d_cdc, d) as half precision tensor
            eigenvalues: (d_cdc,) as half precision tensor
        """
        d = latents_np.shape[1]
        # Get neighbors (exclude self)
        neighbor_idx = indices[point_idx, 1:]  # (k,)
        neighbor_points = latents_np[neighbor_idx]  # (k, d)
        # Clamp distances to prevent overflow (max realistic L2 distance)
        MAX_DISTANCE = 1e10
        neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
        neighbor_dists_sq = neighbor_dists ** 2  # (k,)
        # Compute Gaussian kernel weights with numerical guards
        eps_i = max(epsilon[point_idx], 1e-10)  # Prevent division by zero
        eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)
        # Compute denominator with guard against overflow
        denom = eps_i * eps_neighbors
        denom = np.maximum(denom, 1e-20)  # Additional guard
        # Compute weights with safe exponential
        exp_arg = -neighbor_dists_sq / denom
        exp_arg = np.clip(exp_arg, -50, 0)  # Prevent exp overflow/underflow
        weights = np.exp(exp_arg)
        # Normalize weights, handle edge case of all zeros
        weight_sum = weights.sum()
        if weight_sum < 1e-20 or not np.isfinite(weight_sum):
            # Fallback to uniform weights
            weights = np.ones_like(weights) / len(weights)
        else:
            weights = weights / weight_sum
        # Compute local mean
        m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
        # Center and weight for SVD
        centered = neighbor_points - m_star
        weighted_centered = np.sqrt(weights)[:, None] * centered  # (k, d)
        # Validate input is finite before SVD
        if not np.all(np.isfinite(weighted_centered)):
            logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
            # Fallback: use uniform weights and simple centering
            weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
            m_star = np.mean(neighbor_points, axis=0)
            centered = neighbor_points - m_star
            weighted_centered = np.sqrt(weights_uniform)[:, None] * centered
        # Move to GPU for SVD
        weighted_centered_torch = torch.from_numpy(weighted_centered).to(
            self.device, dtype=torch.float32
        )
        try:
            U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
        except RuntimeError as e:
            # GPU SVD can fail on ill-conditioned inputs; retry on CPU via numpy.
            logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
            try:
                U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False)
                U = torch.from_numpy(U).to(self.device)
                S = torch.from_numpy(S).to(self.device)
                Vh = torch.from_numpy(Vh).to(self.device)
            except np.linalg.LinAlgError as e2:
                logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
                # Return zero eigenvalues/vectors as fallback
                # (all-zero eigenvalues signal "no CDC correction" downstream)
                return (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16)
                )
        # Eigenvalues of Γ_b (squared singular values of the weighted matrix)
        eigenvalues_full = S ** 2
        # Keep top d_cdc
        if len(eigenvalues_full) >= self.d_cdc:
            top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
            top_eigenvectors = Vh[top_idx, :]  # (d_cdc, d)
        else:
            # Pad if k < d_cdc
            top_eigenvalues = eigenvalues_full
            top_eigenvectors = Vh
            # NOTE(review): this inner check is always true inside the else
            # branch; kept as written.
            if len(eigenvalues_full) < self.d_cdc:
                pad_size = self.d_cdc - len(eigenvalues_full)
                top_eigenvalues = torch.cat([
                    top_eigenvalues,
                    torch.zeros(pad_size, device=self.device)
                ])
                top_eigenvectors = torch.cat([
                    top_eigenvectors,
                    torch.zeros(pad_size, d, device=self.device)
                ])
        # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
        # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
        # Then apply gamma: γc_i Γ̂(x^(i))
        #
        # Our implementation:
        # 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor
        # 2. Apply gamma hyperparameter - aligns with paper's γ global scaling
        # 3. Clamp for numerical stability
        #
        # Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents)
        # Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound
        # Step 1: Normalize by the maximum eigenvalue to get relative scales
        # This is the paper's 1/λ_1^i normalization factor
        # (topk returns values sorted descending, so index 0 is the max)
        max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0
        if max_eigenval > 1e-10:
            # Scale so max eigenvalue = 1.0, preserving relative ratios
            top_eigenvalues = top_eigenvalues / max_eigenval
        # Step 2: Apply gamma and clamp to safe range
        # Gamma is the paper's tuneable hyperparameter (defaults to 1.0)
        # Clamping ensures numerical stability and prevents extreme values
        top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0)
        # Convert to fp16 for storage - now safe since eigenvalues are ~0.01-1.0
        # fp16 range: 6e-5 to 65,504, our values are well within this
        eigenvectors_fp16 = top_eigenvectors.cpu().half()
        eigenvalues_fp16 = top_eigenvalues.cpu().half()
        # Cleanup
        del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return eigenvectors_fp16, eigenvalues_fp16

    def compute_for_batch(
        self,
        latents_np: np.ndarray,
        global_indices: List[int]
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute Γ_b for all points in a batch of same-size latents
        Args:
            latents_np: (N, d) numpy array
            global_indices: List of global dataset indices for each latent
        Returns:
            Dict mapping global_idx -> (eigenvectors, eigenvalues)

        Note:
            For batches below MIN_SAMPLES_FOR_CDC the values are zero numpy
            fp16 arrays (Gaussian fallback), while the normal path yields
            fp16 torch tensors; downstream code handles both.
        """
        N, d = latents_np.shape
        # Validate inputs
        if len(global_indices) != N:
            raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")
        print(f"Computing CDC for batch: {N} samples, dim={d}")
        # Handle small sample cases - require minimum samples for meaningful k-NN
        MIN_SAMPLES_FOR_CDC = 5  # Need at least 5 samples for reasonable geometry estimation
        if N < MIN_SAMPLES_FOR_CDC:
            print(f"  Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
            results = {}
            for local_idx in range(N):
                global_idx = global_indices[local_idx]
                # Return zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x)
                eigvecs = np.zeros((self.d_cdc, d), dtype=np.float16)
                eigvals = np.zeros(self.d_cdc, dtype=np.float16)
                results[global_idx] = (eigvecs, eigvals)
            return results
        # Step 1: Build k-NN graph
        print("  Building k-NN graph...")
        distances, indices = self.compute_knn_graph(latents_np)
        # Step 2: Compute bandwidth
        # Use min to handle case where k_bw >= actual neighbors returned
        k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
        epsilon = distances[:, k_bw_actual]
        # Step 3: Compute Γ_b for each point
        results = {}
        print("  Computing Γ_b for each point...")
        for local_idx in tqdm(range(N), desc="  Processing", leave=False):
            global_idx = global_indices[local_idx]
            eigvecs, eigvals = self.compute_gamma_b_single(
                local_idx, latents_np, distances, indices, epsilon
            )
            results[global_idx] = (eigvecs, eigvals)
        return results
class LatentBatcher:
"""
Collects variable-size latents and batches them by size
"""
def __init__(self, size_tolerance: float = 0.0):
"""
Args:
size_tolerance: If > 0, group latents within tolerance % of size
If 0, only exact size matches are batched
"""
self.size_tolerance = size_tolerance
self.samples: List[LatentSample] = []
def add_sample(self, sample: LatentSample):
"""Add a single latent sample"""
self.samples.append(sample)
def add_latent(
self,
latent: Union[np.ndarray, torch.Tensor],
global_idx: int,
latents_npz_path: str,
shape: Optional[Tuple[int, ...]] = None,
metadata: Optional[Dict] = None
):
"""
Add a latent vector with automatic shape tracking
Args:
latent: Latent vector (any shape, will be flattened)
global_idx: Global index in dataset
latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz")
shape: Original shape (if None, uses latent.shape)
metadata: Optional metadata dict
"""
# Convert to numpy and flatten
if isinstance(latent, torch.Tensor):
latent_np = latent.cpu().numpy()
else:
latent_np = latent
original_shape = shape if shape is not None else latent_np.shape
latent_flat = latent_np.flatten()
sample = LatentSample(
latent=latent_flat,
global_idx=global_idx,
shape=original_shape,
latents_npz_path=latents_npz_path,
metadata=metadata
)
self.add_sample(sample)
def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
"""
Group samples by exact shape to avoid resizing distortion.
Each bucket contains only samples with identical latent dimensions.
Buckets with fewer than k_neighbors samples will be skipped during CDC
computation and fall back to standard Gaussian noise.
Returns:
Dict mapping exact_shape -> list of samples with that shape
"""
batches = {}
shapes = set()
for sample in self.samples:
shape_key = sample.shape
shapes.add(shape_key)
# Group by exact shape only - no aspect ratio grouping or resizing
if shape_key not in batches:
batches[shape_key] = []
batches[shape_key].append(sample)
# If more than one unique shape, log a warning
if len(shapes) > 1:
logger.warning(
"Dimension mismatch: %d unique shapes detected. "
"Shapes: %s. Using Gaussian fallback for these samples.",
len(shapes),
shapes
)
return batches
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
"""
Get aspect ratio category for grouping.
Groups images by aspect ratio bins to ensure sufficient samples.
For shape (C, H, W), computes aspect ratio H/W and bins it.
"""
if len(shape) < 3:
return "unknown"
# Extract spatial dimensions (H, W)
h, w = shape[-2], shape[-1]
aspect_ratio = h / w
# Define aspect ratio bins (±15% tolerance)
# Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16)
bins = [
(0.5, 0.65, "9:16"), # Portrait tall
(0.65, 0.85, "3:4"), # Portrait
(0.85, 1.15, "1:1"), # Square
(1.15, 1.50, "4:3"), # Landscape
(1.50, 2.0, "16:9"), # Landscape wide
(2.0, 3.0, "21:9"), # Ultra wide
]
for min_ratio, max_ratio, label in bins:
if min_ratio <= aspect_ratio < max_ratio:
return label
# Fallback for extreme ratios
if aspect_ratio < 0.5:
return "ultra_tall"
else:
return "ultra_wide"
def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
"""Check if two shapes are within tolerance"""
if len(shape1) != len(shape2):
return False
size1 = np.prod(shape1)
size2 = np.prod(shape2)
ratio = abs(size1 - size2) / max(size1, size2)
return ratio <= self.size_tolerance
def __len__(self):
return len(self.samples)
class CDCPreprocessor:
    """
    High-level CDC preprocessing coordinator
    Handles variable-size latents by batching and delegating to CarreDuChampComputer

    Typical flow: add_latent(...) for every sample, then compute_all() once to
    write per-sample CDC cache files next to the latent caches.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda',
        size_tolerance: float = 0.0,
        debug: bool = False,
        adaptive_k: bool = False,
        min_bucket_size: int = 16,
        dataset_dirs: Optional[List[str]] = None
    ):
        # Math backend (k-NN + SVD) shared by all buckets.
        self.computer = CarreDuChampComputer(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=d_cdc,
            gamma=gamma,
            device=device
        )
        # Groups incoming latents by exact shape.
        self.batcher = LatentBatcher(size_tolerance=size_tolerance)
        self.debug = debug
        self.adaptive_k = adaptive_k
        self.min_bucket_size = min_bucket_size
        self.dataset_dirs = dataset_dirs or []
        # Hash of (dataset dirs + CDC params) used to namespace cache files.
        self.config_hash = self._compute_config_hash()

    def _compute_config_hash(self) -> str:
        """
        Compute a short hash of CDC configuration for filename uniqueness.
        Hash includes:
        - Sorted dataset/subset directory paths
        - CDC parameters (k_neighbors, d_cdc, gamma)
        This ensures CDC files are invalidated when:
        - Dataset composition changes (different dirs)
        - CDC parameters change
        Returns:
            8-character hex hash
        """
        import hashlib
        # Sort dataset dirs for consistent hashing
        dirs_str = "|".join(sorted(self.dataset_dirs))
        # Include CDC parameters
        config_str = f"{dirs_str}|k={self.computer.k}|d={self.computer.d_cdc}|gamma={self.computer.gamma}"
        # Create short hash (8 chars is enough for uniqueness in this context)
        hash_obj = hashlib.sha256(config_str.encode())
        return hash_obj.hexdigest()[:8]

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        latents_npz_path: str,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a single latent to the preprocessing queue
        Args:
            latent: Latent vector (will be flattened)
            global_idx: Global dataset index
            latents_npz_path: Path to the latent cache file
            shape: Original shape (C, H, W)
            metadata: Optional metadata
        """
        self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)

    @staticmethod
    def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str:
        """
        Get CDC cache path from latents cache path
        Includes optional config_hash to ensure CDC files are unique to dataset/subset
        configuration and CDC parameters. This prevents using stale CDC files when
        the dataset composition or CDC settings change.
        Args:
            latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
            config_hash: Optional 8-char hash of (dataset_dirs + CDC params)
                        If None, returns path without hash (for backward compatibility)
        Returns:
            CDC cache path:
            - With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
            - Without: "image_0512x0768_flux_cdc.npz"
        """
        # Path.with_stem requires Python 3.9+.
        path = Path(latents_npz_path)
        if config_hash:
            return str(path.with_stem(f"{path.stem}_cdc_{config_hash}"))
        else:
            return str(path.with_stem(f"{path.stem}_cdc"))

    def compute_all(self) -> int:
        """
        Compute Γ_b for all added latents and save individual CDC files next to each latent cache
        Returns:
            Number of CDC files saved
        """
        # Get batches by exact size (no resizing)
        batches = self.batcher.get_batches()
        # Count samples that will get CDC vs fallback
        k_neighbors = self.computer.k
        # Adaptive mode lowers the skip threshold to min_bucket_size.
        min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
        if self.adaptive_k:
            samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
        else:
            samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
        samples_fallback = len(self.batcher) - samples_with_cdc
        if self.debug:
            print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
            if self.adaptive_k:
                print(f"  Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
            print(f"  Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
            print(f"  Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
        else:
            mode = "adaptive" if self.adaptive_k else "fixed"
            logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
        # Storage for results
        all_results = {}
        # Process each bucket with progress bar
        bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug) if not self.debug else batches.items()
        for shape, samples in bucket_iter:
            num_samples = len(samples)
            if self.debug:
                print(f"\n{'='*60}")
                print(f"Bucket: {shape} ({num_samples} samples)")
                print(f"{'='*60}")
            # Determine effective k for this bucket
            if self.adaptive_k:
                # Adaptive mode: skip if below minimum, otherwise use best available k
                if num_samples < min_threshold:
                    if self.debug:
                        print(f"  ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
                        print("  → These samples will use standard Gaussian noise (no CDC)")
                    # Store zero eigenvectors/eigenvalues (Gaussian fallback)
                    # NOTE(review): assumes 3D (C, H, W) bucket shapes — confirm
                    # against how latents are added upstream.
                    C, H, W = shape
                    d = C * H * W
                    for sample in samples:
                        eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
                        eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
                        all_results[sample.global_idx] = (eigvecs, eigvals)
                    continue
                # Use adaptive k for this bucket
                k_effective = min(k_neighbors, num_samples - 1)
            else:
                # Fixed mode: skip if below k_neighbors
                if num_samples < k_neighbors:
                    if self.debug:
                        print(f"  ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
                        print("  → These samples will use standard Gaussian noise (no CDC)")
                    # Store zero eigenvectors/eigenvalues (Gaussian fallback)
                    C, H, W = shape
                    d = C * H * W
                    for sample in samples:
                        eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
                        eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
                        all_results[sample.global_idx] = (eigvecs, eigvals)
                    continue
                k_effective = k_neighbors
            # Collect latents (no resizing needed - all same shape)
            latents_list = []
            global_indices = []
            for sample in samples:
                global_indices.append(sample.global_idx)
                latents_list.append(sample.latent)  # Already flattened
            latents_np = np.stack(latents_list, axis=0)  # (N, C*H*W)
            # Compute CDC for this batch with effective k
            if self.debug:
                if self.adaptive_k and k_effective < k_neighbors:
                    print(f"  Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
                else:
                    print(f"  Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
            # Temporarily override k for this bucket
            original_k = self.computer.k
            self.computer.k = k_effective
            batch_results = self.computer.compute_for_batch(latents_np, global_indices)
            self.computer.k = original_k
            # No resizing needed - eigenvectors are already correct size
            if self.debug:
                print(f"  ✓ CDC computed for {len(batch_results)} samples (no resizing)")
            # Merge into overall results
            all_results.update(batch_results)
        # Save individual CDC files next to each latent cache
        if self.debug:
            print(f"\n{'='*60}")
            print("Saving individual CDC files...")
            print(f"{'='*60}")
        files_saved = 0
        total_size = 0
        save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples
        for sample in save_iter:
            # Get CDC cache path with config hash
            cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash)
            # Get CDC results for this sample
            if sample.global_idx in all_results:
                eigvecs, eigvals = all_results[sample.global_idx]
                # Convert to numpy if needed
                # (normal path yields torch tensors, fallback path numpy arrays)
                if isinstance(eigvecs, torch.Tensor):
                    eigvecs = eigvecs.numpy()
                if isinstance(eigvals, torch.Tensor):
                    eigvals = eigvals.numpy()
                # Save metadata and CDC results
                np.savez(
                    cdc_path,
                    eigenvectors=eigvecs,
                    eigenvalues=eigvals,
                    shape=np.array(sample.shape),
                    k_neighbors=self.computer.k,
                    d_cdc=self.computer.d_cdc,
                    gamma=self.computer.gamma
                )
                files_saved += 1
                total_size += Path(cdc_path).stat().st_size
                logger.debug(f"Saved CDC file: {cdc_path}")
        total_size_mb = total_size / 1024 / 1024
        logger.info(f"Saved {files_saved} CDC files, total size: {total_size_mb:.2f} MB")
        return files_saved
class GammaBDataset:
    """
    Efficient loader for Γ_b matrices during training
    Loads from individual CDC cache files next to latent caches
    """

    def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None):
        """
        Initialize CDC dataset loader
        Args:
            device: Device for loading tensors
            config_hash: Optional config hash to use for CDC file lookup.
                        If None, uses default naming without hash.
        """
        # Fall back to CPU automatically when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.config_hash = config_hash
        if config_hash:
            logger.info(f"CDC loader initialized (hash: {config_hash})")
        else:
            logger.info("CDC loader initialized (no hash, backward compatibility mode)")

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        latents_npz_paths: List[str],
        device: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get Γ_b^(1/2) components for a batch of latents
        Args:
            latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
            device: Device to load to (defaults to self.device)
        Returns:
            eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
            eigenvalues: (B, d_cdc)
        Raises:
            FileNotFoundError: If a CDC cache file is missing for any path.
            RuntimeError: If the batch mixes eigenvector dimensions (i.e. the
                dataloader batched images of different latent sizes).
        """
        if device is None:
            device = self.device
        eigenvectors_list = []
        eigenvalues_list = []
        for latents_npz_path in latents_npz_paths:
            # Get CDC cache path with config hash
            cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash)
            # Load CDC data
            if not Path(cdc_path).exists():
                raise FileNotFoundError(
                    f"CDC cache file not found: {cdc_path}. "
                    f"Make sure to run CDC preprocessing before training."
                )
            data = np.load(cdc_path)
            # Stored as fp16 on disk; upcast to fp32 for stable math.
            eigvecs = torch.from_numpy(data['eigenvectors']).to(device).float()
            eigvals = torch.from_numpy(data['eigenvalues']).to(device).float()
            eigenvectors_list.append(eigvecs)
            eigenvalues_list.append(eigvals)
        # Stack - all should have same d_cdc and d within a batch (enforced by bucketing)
        # Check if all eigenvectors have the same dimension
        dims = [ev.shape[1] for ev in eigenvectors_list]
        if len(set(dims)) > 1:
            # Dimension mismatch! This shouldn't happen with proper bucketing
            # but can occur if batch contains mixed sizes
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
                f"Latent paths: {latents_npz_paths}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )
        eigenvectors = torch.stack(eigenvectors_list, dim=0)
        eigenvalues = torch.stack(eigenvalues_list, dim=0)
        return eigenvectors, eigenvalues

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2)
        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d) or (B, C, H, W) - will be flattened if needed
            t: (B,) or scalar time
        Returns:
            result: Same shape as input x
        Note:
            Gradients flow through this function for backprop during training.
            Γ_b^(1/2) is applied via its low-rank form V^T Λ^(1/2) V, so x's
            component outside the span of the eigenvectors is scaled by (1-t).
        """
        # Store original shape to restore later
        orig_shape = x.shape
        # Flatten x if it's 4D
        if x.dim() == 4:
            B, C, H, W = x.shape
            x = x.reshape(B, -1)  # (B, C*H*W)
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)
        if t.dim() == 0:
            t = t.expand(x.shape[0])
        # Column vector so t broadcasts over the feature dimension.
        t = t.view(-1, 1)
        # Early return for t=0 to avoid numerical errors (Σ_0 = I)
        if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)
        # Check if CDC is disabled (all eigenvalues are zero)
        # This happens for buckets with < k_neighbors samples
        if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8):
            # Fallback to standard Gaussian noise (no CDC correction)
            return x.reshape(orig_shape)
        # Γ_b^(1/2) @ x using low-rank representation
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
        # Σ_t @ x
        result = (1 - t) * x + t * gamma_sqrt_x
        # Restore original shape
        result = result.reshape(orig_shape)
        return result

View File

@@ -2,10 +2,8 @@ import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
from accelerate import Accelerator, PartialState
@@ -183,7 +181,7 @@ def sample_image_inference(
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info("negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
@@ -468,9 +466,76 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
# Global set to track samples that have already been warned about shape mismatches
# This prevents log spam during training (warning once per sample is sufficient)
_cdc_warned_samples = set()
def apply_cdc_noise_transformation(
    noise: torch.Tensor,
    timesteps: torch.Tensor,
    num_timesteps: int,
    gamma_b_dataset,
    latents_npz_paths,
    device
) -> torch.Tensor:
    """
    Apply the CDC-FM geometry-aware covariance to a batch of Gaussian noise.

    Args:
        noise: (B, C, H, W) standard Gaussian noise
        timesteps: (B,) timesteps for this batch
        num_timesteps: Total number of timesteps in scheduler
        gamma_b_dataset: GammaBDataset holding cached CDC matrices
        latents_npz_paths: Latent cache paths identifying each sample's CDC entry
        device: Device to load CDC matrices to

    Returns:
        Noise of the same shape, transformed with geometry-aware covariance
    """
    # Normalize the requested device; "cuda" and "cuda:0" must not be treated
    # as a mismatch, so compare by device type for cuda/cpu.
    wanted = device if isinstance(device, torch.device) else torch.device(device)
    have = noise.device
    compatible = (
        have == wanted
        or (have.type == "cuda" and wanted.type == "cuda")
        or (have.type == "cpu" and wanted.type == "cpu")
    )
    if not compatible:
        logger.warning(
            f"CDC device mismatch: noise on {have} but CDC loading to {wanted}. "
            f"Transferring noise to {wanted} to avoid errors."
        )
        noise = noise.to(wanted)
        device = wanted

    # CDC-FM expects t in [0, 1].
    t_unit = timesteps.to(device) / num_timesteps

    bsz, ch, height, width = noise.shape
    # One batched lookup for all samples, then one batched transform.
    vecs, vals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device)
    flattened = noise.reshape(bsz, -1)
    transformed = gamma_b_dataset.compute_sigma_t_x(vecs, vals, flattened, t_unit)
    return transformed.reshape(bsz, ch, height, width)
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
gamma_b_dataset=None, latents_npz_paths=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get noisy model input and timesteps for training.
Args:
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
@@ -514,6 +579,17 @@ def get_noisy_model_input_and_timesteps(
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and latents_npz_paths is not None:
noise = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=num_timesteps,
gamma_b_dataset=gamma_b_dataset,
latents_npz_paths=latents_npz_paths,
device=device
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:

View File

@@ -40,6 +40,8 @@ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from library.cdc_fm import CDCPreprocessor
from diffusers.optimization import (
SchedulerType as DiffusersSchedulerType,
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
@@ -1569,11 +1571,17 @@ class BaseDataset(torch.utils.data.Dataset):
flippeds = [] # 変数名が微妙
text_encoder_outputs_list = []
custom_attributes = []
image_keys = [] # CDC-FM: track image keys for CDC lookup
latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
# CDC-FM: Store image_key and latents_npz path for CDC lookup
image_keys.append(image_key)
latents_npz_paths.append(image_info.latents_npz)
custom_attributes.append(subset.custom_attributes)
# in case of fine tuning, is_reg is always False
@@ -1819,6 +1827,9 @@ class BaseDataset(torch.utils.data.Dataset):
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
# CDC-FM: Add latents_npz paths to batch for CDC lookup
example["latents_npz"] = latents_npz_paths
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -2690,6 +2701,165 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def cache_cdc_gamma_b(
    self,
    cdc_output_path: str,
    k_neighbors: int = 256,
    k_bandwidth: int = 8,
    d_cdc: int = 8,
    gamma: float = 1.0,
    force_recache: bool = False,
    accelerator: Optional["Accelerator"] = None,
    debug: bool = False,
    adaptive_k: bool = False,
    min_bucket_size: int = 16,
) -> Optional[str]:
    """
    Cache CDC Γ_b matrices for all latents in the dataset group.

    CDC files are saved as individual .npz files next to each latent cache file.
    For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc.npz

    Only the main process performs the computation; all other processes wait on
    the accelerator barrier and return the config hash immediately.

    Args:
        cdc_output_path: Deprecated (CDC uses per-file caching now)
        k_neighbors: k-NN neighbors
        k_bandwidth: Bandwidth estimation neighbors
        d_cdc: CDC subspace dimension
        gamma: CDC strength
        force_recache: Force recompute even if cache exists
        accelerator: For multi-GPU support
        debug: Enable CDCPreprocessor debug output
        adaptive_k: Shrink k for buckets smaller than k_neighbors
        min_bucket_size: Minimum bucket size used with adaptive_k
    Returns:
        The preprocessor's config hash (used to locate the per-file caches),
        so training can initialize GammaBDataset with it.
    """
    from pathlib import Path

    # Collect dataset/subset directories for config hash (hash must change when
    # the underlying data location changes, so stale caches are invalidated)
    dataset_dirs = []
    for dataset in self.datasets:
        # Get the directory containing the images
        if hasattr(dataset, 'image_dir'):
            dataset_dirs.append(str(dataset.image_dir))
        # Fallback: use first image's parent directory
        elif dataset.image_data:
            first_image = next(iter(dataset.image_data.values()))
            dataset_dirs.append(str(Path(first_image.absolute_path).parent))

    # Create preprocessor to get config hash (also used for the actual compute below)
    preprocessor = CDCPreprocessor(
        k_neighbors=k_neighbors,
        k_bandwidth=k_bandwidth,
        d_cdc=d_cdc,
        gamma=gamma,
        device="cuda" if torch.cuda.is_available() else "cpu",
        debug=debug,
        adaptive_k=adaptive_k,
        min_bucket_size=min_bucket_size,
        dataset_dirs=dataset_dirs
    )
    logger.info(f"CDC config hash: {preprocessor.config_hash}")

    # Check if CDC caches already exist (unless force_recache)
    if not force_recache:
        all_cached = self._check_cdc_caches_exist(preprocessor.config_hash)
        if all_cached:
            logger.info("All CDC cache files found, skipping preprocessing")
            return preprocessor.config_hash
        else:
            logger.info("Some CDC cache files missing, will compute")

    # Only main process computes CDC; others sync on the barrier and exit early
    is_main = accelerator is None or accelerator.is_main_process
    if not is_main:
        if accelerator is not None:
            accelerator.wait_for_everyone()
        return preprocessor.config_hash

    logger.info("Starting CDC-FM preprocessing")
    logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")

    # Get caching strategy for loading latents from disk
    from library.strategy_base import LatentsCachingStrategy
    caching_strategy = LatentsCachingStrategy.get_strategy()

    # Collect all latents from all datasets
    for dataset_idx, dataset in enumerate(self.datasets):
        logger.info(f"Loading latents from dataset {dataset_idx}...")
        image_infos = list(dataset.image_data.values())
        for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset {dataset_idx}")):
            # Load latent from disk or memory
            if info.latents is not None:
                latent = info.latents
            elif info.latents_npz is not None:
                # Load from disk
                latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso)
                if latent is None:
                    logger.warning(f"Failed to load latent from {info.latents_npz}, skipping")
                    continue
            else:
                logger.warning(f"No latent found for {info.absolute_path}, skipping")
                continue

            # Add to preprocessor (with unique global index across all datasets)
            # NOTE(review): this sum is O(num_datasets) per image; acceptable for
            # typical dataset counts but quadratic-ish for many datasets.
            actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx

            # Get latents_npz_path - will be set whether caching to disk or memory.
            # (Only reachable via the in-memory branch above, where latents_npz may be unset.)
            if info.latents_npz is None:
                # If not set, generate the path from the caching strategy
                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso)

            preprocessor.add_latent(
                latent=latent,
                global_idx=actual_global_idx,
                latents_npz_path=info.latents_npz,
                shape=latent.shape,
                metadata={"image_key": info.image_key}
            )

    # Compute and save individual CDC files
    logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
    files_saved = preprocessor.compute_all()
    logger.info(f"Saved {files_saved} CDC cache files")

    # Release the non-main processes waiting on the barrier above
    if accelerator is not None:
        accelerator.wait_for_everyone()

    # Return config hash so training can initialize GammaBDataset with it
    return preprocessor.config_hash
def _check_cdc_caches_exist(self, config_hash: str) -> bool:
    """
    Return True only when every latent in the group has a CDC cache file on disk.

    Args:
        config_hash: Hash used to derive each per-file CDC cache filename
    """
    from pathlib import Path

    total = 0
    missing = 0
    for ds in self.datasets:
        for info in ds.image_data.values():
            total += 1
            # Without a latents_npz path there is no filename to derive the
            # CDC cache path from, so this entry cannot be checked.
            if info.latents_npz is None:
                continue
            candidate = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash)
            if not Path(candidate).exists():
                missing += 1

    if missing > 0:
        logger.info(f"Found {missing}/{total} missing CDC cache files")
        return False
    logger.debug(f"All {total} CDC cache files exist")
    return True
def set_caching_mode(self, caching_mode):
    """Propagate the caching mode to every child dataset in the group."""
    for child in self.datasets:
        child.set_caching_mode(caching_mode)

View File

@@ -0,0 +1,183 @@
import torch
from typing import Union
class MockGammaBDataset:
    """
    Lightweight stand-in for GammaBDataset used to test gradient flow
    without loading any cache files.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts and ignores all arguments; only records a working device.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Simplified Σ_t @ x used for testing.

        Computes (1 - t) * x + t * (V sqrt(Λ) V^T x) from the low-rank
        eigen-factors, preserving the input's original shape.
        """
        input_shape = x.shape
        # Work on flattened vectors regardless of whether x arrives as (B, C, H, W).
        if x.dim() == 4:
            batch = x.shape[0]
            x = x.reshape(batch, -1)  # (B, C*H*W)
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        # Sanity-check the low-rank factors against the flattened input.
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"

        # At t == 0 the transform is the identity; skip the math unless t is
        # learnable (in which case we must keep it in the graph).
        if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
            return x.reshape(input_shape)

        # Σ_t @ x via the low-rank factors: project, scale by sqrt(λ), lift back.
        projected = torch.einsum('bkd,bd->bk', eigenvectors, x)
        scaled = torch.sqrt(eigenvalues.clamp(min=1e-10)) * projected
        lifted = torch.einsum('bkd,bk->bd', eigenvectors, scaled)

        # Interpolate between the original vector and its Γ_b^(1/2) image.
        blended = (1 - t) * x + t * lifted
        return blended.reshape(input_shape)
class TestCDCAdvanced:
    def setup_method(self):
        """Prepare consistent test environment"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _make_inputs(self, batch_size=4, latent_dim=64, rank=8):
        # Build a learnable latent plus unit-norm eigenvectors and bounded
        # eigenvalues on the test device (same RNG consumption order as before:
        # latent, then eigenvectors, then eigenvalues).
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)
        vecs = torch.randn(batch_size, rank, latent_dim, device=self.device)
        vals = torch.rand(batch_size, rank, device=self.device)
        vecs /= torch.norm(vecs, dim=-1, keepdim=True)
        vals = torch.clamp(vals, min=1e-4, max=1.0)
        return latent, vecs, vals

    def test_gradient_flow_preservation(self):
        """
        Verify that gradient flow is preserved even for near-zero time steps
        with learnable time embeddings
        """
        torch.manual_seed(42)  # reproducibility
        # Learnable time embedding with a small initial value
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)
        latent, vecs, vals = self._make_inputs()

        noisy = MockGammaBDataset().compute_sigma_t_x(vecs, vals, latent, t)
        # Dummy loss to drive backprop through both inputs
        noisy.sum().backward()

        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        t_mag = torch.abs(t.grad).sum()
        latent_mag = torch.abs(latent.grad).sum()
        assert t_mag > 0, f"Time embedding gradient is zero: {t_mag}"
        assert latent_mag > 0, f"Input latent gradient is zero: {latent_mag}"

        # Optional: Print gradient details for debugging
        print(f"Time embedding gradient magnitude: {t_mag}")
        print(f"Latent gradient magnitude: {latent_mag}")

    def test_gradient_flow_with_different_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        for time_val in [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]:
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)
            latent, vecs, vals = self._make_inputs()

            noisy = MockGammaBDataset().compute_sigma_t_x(vecs, vals, latent, t)
            noisy.sum().backward()

            t_mag = torch.abs(t.grad).sum()
            latent_mag = torch.abs(latent.grad).sum()
            assert t_mag > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_mag > 0, f"Input latent gradient is zero for t={time_val}"

            # Clear gradients before the next iteration
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()
def pytest_configure(config):
    """
    Register the custom marker used by the CDC-FM gradient-flow tests.
    """
    config.addinivalue_line(
        "markers",
        "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching",
    )

View File

@@ -0,0 +1,157 @@
"""
Test CDC config hash generation and cache invalidation
"""
import pytest
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor
class TestCDCConfigHash:
    """
    Test that CDC config hash properly invalidates cache when dataset or parameters change
    """

    @staticmethod
    def _build(tmp_path, dirs=("dataset1",), **overrides):
        # Construct a CPU preprocessor with the baseline parameters, applying
        # any per-test overrides on top.
        params = dict(k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0)
        params.update(overrides)
        return CDCPreprocessor(
            device="cpu",
            dataset_dirs=[str(tmp_path / d) for d in dirs],
            **params,
        )

    def test_same_config_produces_same_hash(self, tmp_path):
        """Identical configurations must produce identical hashes."""
        assert self._build(tmp_path).config_hash == self._build(tmp_path).config_hash

    def test_different_dataset_dirs_produce_different_hash(self, tmp_path):
        """Different dataset directories must produce different hashes."""
        one = self._build(tmp_path, dirs=("dataset1",))
        two = self._build(tmp_path, dirs=("dataset2",))
        assert one.config_hash != two.config_hash

    def test_different_k_neighbors_produces_different_hash(self, tmp_path):
        """Different k_neighbors values must produce different hashes."""
        one = self._build(tmp_path, k_neighbors=5)
        two = self._build(tmp_path, k_neighbors=10)
        assert one.config_hash != two.config_hash

    def test_different_d_cdc_produces_different_hash(self, tmp_path):
        """Different d_cdc values must produce different hashes."""
        one = self._build(tmp_path, d_cdc=4)
        two = self._build(tmp_path, d_cdc=8)
        assert one.config_hash != two.config_hash

    def test_different_gamma_produces_different_hash(self, tmp_path):
        """Different gamma values must produce different hashes."""
        one = self._build(tmp_path, gamma=1.0)
        two = self._build(tmp_path, gamma=2.0)
        assert one.config_hash != two.config_hash

    def test_multiple_dataset_dirs_order_independent(self, tmp_path):
        """Directory order must not affect the hash (dirs are sorted)."""
        forward = self._build(tmp_path, dirs=("dataset1", "dataset2"))
        backward = self._build(tmp_path, dirs=("dataset2", "dataset1"))
        assert forward.config_hash == backward.config_hash

    def test_hash_length_is_8_chars(self, tmp_path):
        """The hash must be exactly 8 hex characters."""
        digest = self._build(tmp_path).config_hash
        assert len(digest) == 8
        int(digest, 16)  # raises ValueError if not hex

    def test_filename_includes_hash(self, tmp_path):
        """CDC filenames must embed the config hash."""
        pre = self._build(tmp_path)
        latents_path = str(tmp_path / "image_0512x0768_flux.npz")
        cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, pre.config_hash)
        # Expected form: image_0512x0768_flux_cdc_<hash>.npz
        assert cdc_path == str(tmp_path / f"image_0512x0768_flux_cdc_{pre.config_hash}.npz")

    def test_backward_compatibility_no_hash(self, tmp_path):
        """Without a hash, the legacy suffix-free filename must be produced."""
        latents_path = str(tmp_path / "image_0512x0768_flux.npz")
        cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None)
        # Expected form: image_0512x0768_flux_cdc.npz (no hash suffix)
        assert cdc_path == str(tmp_path / "image_0512x0768_flux_cdc.npz")
# Allow running this test module directly (outside of a pytest session).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,320 @@
"""
CDC Preprocessor and Device Consistency Tests
This module provides testing of:
1. CDC Preprocessor functionality
2. Device consistency handling
3. GammaBDataset loading and usage
4. End-to-end CDC workflow verification
"""
import pytest
import logging
import torch
from pathlib import Path
from safetensors.torch import save_file
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class TestCDCPreprocessorIntegration:
    """
    Comprehensive testing of CDC preprocessing and device handling
    """

    @staticmethod
    def _feed(preprocessor, tmp_path, indices, shape, reso_tag):
        # Register random latents with the preprocessor and return their npz paths.
        paths = []
        for idx in indices:
            sample = torch.randn(*shape, dtype=torch.float32)
            npz = str(tmp_path / f"test_image_{idx}_{reso_tag}_flux.npz")
            paths.append(npz)
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )
        return paths

    def test_basic_preprocessor_workflow(self, tmp_path):
        """
        Test basic CDC preprocessing with small dataset
        """
        pre = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],  # dataset_dirs feeds the config hash
        )
        self._feed(pre, tmp_path, range(10), (16, 4, 4), "0004x0004")

        # Compute and save; one CDC file per latent.
        assert pre.compute_all() == 10

        # Inspect the first CDC file (name embeds the config hash).
        first_npz = str(tmp_path / "test_image_0_0004x0004_flux.npz")
        cdc_file = Path(CDCPreprocessor.get_cdc_npz_path(first_npz, pre.config_hash))
        assert cdc_file.exists()

        import numpy as np
        payload = np.load(cdc_file)
        assert payload['k_neighbors'] == 5
        assert payload['d_cdc'] == 4
        # Eigen-factors are stored with d_cdc leading rows.
        assert payload['eigenvectors'].shape[0] == 4
        assert payload['eigenvalues'].shape[0] == 4

    def test_preprocessor_with_different_shapes(self, tmp_path):
        """
        Test CDC preprocessing with variable-size latents (bucketing)
        """
        pre = CDCPreprocessor(
            k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],  # dataset_dirs feeds the config hash
        )
        # Two bucket shapes: five (16, 4, 4) and five (16, 8, 8) latents.
        self._feed(pre, tmp_path, range(5), (16, 4, 4), "0004x0004")
        self._feed(pre, tmp_path, range(5, 10), (16, 8, 8), "0008x0008")

        # Both shape groups must be processed.
        assert pre.compute_all() == 10

        import numpy as np
        cdc_small = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_0_0004x0004_flux.npz"), pre.config_hash
        )
        cdc_large = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_5_0008x0008_flux.npz"), pre.config_hash
        )
        assert tuple(np.load(cdc_small)['shape']) == (16, 4, 4)
        assert tuple(np.load(cdc_large)['shape']) == (16, 8, 8)
class TestDeviceConsistency:
    """
    Test device handling and consistency for CDC transformations
    """

    @staticmethod
    def _prepare(tmp_path):
        # Build a CPU CDC cache over ten (16, 32, 32) latents, then load it back
        # through GammaBDataset. Returns (dataset, npz_paths, latent_shape).
        pre = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],  # dataset_dirs feeds the config hash
        )
        shape = (16, 32, 32)
        paths = []
        for idx in range(10):
            sample = torch.randn(*shape, dtype=torch.float32)
            npz = str(tmp_path / f"test_image_{idx}_0032x0032_flux.npz")
            paths.append(npz)
            pre.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz,
                shape=shape,
                metadata={'image_key': f'test_image_{idx}'},
            )
        pre.compute_all()
        return GammaBDataset(device="cpu", config_hash=pre.config_hash), paths, shape

    def test_matching_devices_no_warning(self, tmp_path, caplog):
        """
        Test that no warnings are emitted when devices match.
        """
        dataset, paths, shape = self._prepare(tmp_path)
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
        ts = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")

        with caplog.at_level(logging.WARNING):
            caplog.clear()
            _ = apply_cdc_noise_transformation(
                noise=noise,
                timesteps=ts,
                num_timesteps=1000,
                gamma_b_dataset=dataset,
                latents_npz_paths=paths[:2],
                device="cpu",
            )

        # Everything was on CPU, so no mismatch warning may appear.
        mismatch_logs = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()]
        assert len(mismatch_logs) == 0, "Should not warn when devices match"

    def test_device_mismatch_handling(self, tmp_path):
        """
        Test that CDC transformation handles device mismatch gracefully
        """
        dataset, paths, shape = self._prepare(tmp_path)
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
        ts = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")

        result = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=ts,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            latents_npz_paths=paths[:2],
            device="cpu",
        )

        # The output must keep shape, device, differentiability, and be finite.
        assert result.shape == noise.shape
        assert result.device == noise.device
        assert result.requires_grad
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()

        # Gradients must flow back to the input noise.
        result.sum().backward()
        assert noise.grad is not None
class TestCDCEndToEnd:
    """
    End-to-end CDC workflow tests
    """

    def test_full_preprocessing_usage_workflow(self, tmp_path):
        """
        Test complete workflow: preprocess -> save -> load -> use
        """
        # Step 1: preprocess ten small latents on CPU.
        pre = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],  # dataset_dirs feeds the config hash
        )
        total = 10
        npz_paths = []
        for idx in range(total):
            sample = torch.randn(16, 4, 4, dtype=torch.float32)
            npz = str(tmp_path / f"test_image_{idx}_0004x0004_flux.npz")
            npz_paths.append(npz)
            pre.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )
        assert pre.compute_all() == total

        # Step 2: load the caches back via GammaBDataset (keyed by config hash).
        gamma_ds = GammaBDataset(device="cpu", config_hash=pre.config_hash)

        # Step 3: use the factors in a mock training step.
        bsz = 3
        flat = torch.randn(bsz, 256)  # flattened 16*4*4 = 256
        t = torch.rand(bsz)
        picked = [npz_paths[0], npz_paths[5], npz_paths[9]]

        vecs, vals = gamma_ds.get_gamma_b_sqrt(picked, device="cpu")
        mixed = gamma_ds.compute_sigma_t_x(vecs, vals, flat, t)

        # Output must be well-formed.
        assert mixed.shape == flat.shape
        assert not torch.isnan(mixed).any()
        assert torch.isfinite(mixed).all()

        # The transform must be the identity at t=0 and non-trivial at t=1.
        at_zero = gamma_ds.compute_sigma_t_x(vecs, vals, flat, torch.zeros(bsz))
        at_one = gamma_ds.compute_sigma_t_x(vecs, vals, flat, torch.ones(bsz))
        assert torch.allclose(at_zero, flat, atol=1e-6)
        assert not torch.allclose(at_one, flat, atol=0.1)
def pytest_configure(config):
    """
    Register the custom markers used by the CDC test suite.
    """
    marker_lines = (
        "device_consistency: mark test to verify device handling in CDC transformations",
        "preprocessor: mark test to verify CDC preprocessing workflow",
        "end_to_end: mark test to verify full CDC workflow",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
# Allow running this test module directly (outside of a pytest session);
# -s disables output capture so debug prints are visible.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,292 @@
"""
Standalone tests for CDC-FM per-file caching.
These tests focus on the current CDC-FM per-file caching implementation
with hash-based cache validation.
"""
from pathlib import Path
import pytest
import torch
import numpy as np
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPreprocessor:
    """Test CDC preprocessing functionality with per-file caching"""

    @staticmethod
    def _register(preprocessor, tmp_path, indices, shape, reso_tag):
        # Register random latents with the preprocessor under bucketed filenames.
        for idx in indices:
            sample = torch.randn(*shape, dtype=torch.float32)
            npz = str(tmp_path / f"test_image_{idx}_{reso_tag}_flux.npz")
            preprocessor.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=npz,
                shape=sample.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )

    def test_cdc_preprocessor_basic_workflow(self, tmp_path):
        """Test basic CDC preprocessing with small dataset"""
        pre = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
        )
        self._register(pre, tmp_path, range(10), (16, 4, 4), "0004x0004")

        # Compute and save (creates per-file CDC caches).
        assert pre.compute_all() == 10

        # Inspect the first CDC file.
        first_npz = str(tmp_path / "test_image_0_0004x0004_flux.npz")
        cdc_file = Path(CDCPreprocessor.get_cdc_npz_path(first_npz, pre.config_hash))
        assert cdc_file.exists()

        payload = np.load(cdc_file)
        assert payload['k_neighbors'] == 5
        assert payload['d_cdc'] == 4
        # Eigen-factors are stored with d_cdc leading rows.
        assert payload['eigenvectors'].shape[0] == 4
        assert payload['eigenvalues'].shape[0] == 4

    def test_cdc_preprocessor_different_shapes(self, tmp_path):
        """Test CDC preprocessing with variable-size latents (bucketing)"""
        pre = CDCPreprocessor(
            k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
        )
        # Two bucket shapes: five (16, 4, 4) and five (16, 8, 8) latents.
        self._register(pre, tmp_path, range(5), (16, 4, 4), "0004x0004")
        self._register(pre, tmp_path, range(5, 10), (16, 8, 8), "0008x0008")

        # Both shape groups must be processed.
        assert pre.compute_all() == 10

        cdc_small = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_0_0004x0004_flux.npz"), pre.config_hash
        )
        cdc_large = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_5_0008x0008_flux.npz"), pre.config_hash
        )
        assert tuple(np.load(cdc_small)['shape']) == (16, 4, 4)
        assert tuple(np.load(cdc_large)['shape']) == (16, 8, 8)
class TestGammaBDataset:
    """Exercise GammaBDataset loading and Γ_b retrieval with per-file caching."""

    @pytest.fixture
    def sample_cdc_cache(self, tmp_path):
        """Create sample CDC cache files on disk and return their locations.

        Uses 20 samples so the k-NN stage has enough neighbors; a real run
        would want far more (256+ recommended), but adaptive k makes this
        sufficient for testing.
        """
        pre = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
            adaptive_k=True,  # let k shrink for the small dataset
            min_bucket_size=5,
        )
        npz_paths = []
        for idx in range(20):
            lat = torch.randn(16, 8, 8, dtype=torch.float32)  # C=16 -> d=1024 when flattened
            path = str(tmp_path / f"test_{idx}_0008x0008_flux.npz")
            npz_paths.append(path)
            pre.add_latent(
                latent=lat,
                global_idx=idx,
                latents_npz_path=path,
                shape=lat.shape,
                metadata={'image_key': f'test_{idx}'},
            )
        pre.compute_all()
        return tmp_path, npz_paths, pre.config_hash

    def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
        """GammaBDataset should load per-file CDC components with the right shapes."""
        _, npz_paths, cfg_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=cfg_hash)
        eigvecs, eigvals = ds.get_gamma_b_sqrt([npz_paths[0]], device="cpu")
        assert eigvecs.shape[0] == 1  # batch size
        assert eigvecs.shape[1] == 4  # d_cdc
        assert eigvals.shape == (1, 4)  # (batch, d_cdc)

    def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
        """Γ_b^(1/2) components for a batch of paths are well-shaped and positive."""
        _, npz_paths, cfg_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=cfg_hash)
        batch_paths = [npz_paths[0], npz_paths[2], npz_paths[4]]
        eigvecs, eigvals = ds.get_gamma_b_sqrt(batch_paths, device="cpu")
        assert eigvecs.shape[0] == 3  # batch
        assert eigvecs.shape[1] == 4  # d_cdc
        assert eigvals.shape == (3, 4)  # (batch, d_cdc)
        # Eigenvalues of Γ_b^(1/2) must be strictly positive.
        assert torch.all(eigvals > 0)

    def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
        """compute_sigma_t_x is the identity at t=0."""
        _, npz_paths, cfg_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=cfg_hash)
        x = torch.randn(3, 1024)  # (B, d) flattened latents
        t = torch.zeros(3)  # t = 0 for every sample
        eigvecs, eigvals = ds.get_gamma_b_sqrt(
            [npz_paths[0], npz_paths[1], npz_paths[2]], device="cpu"
        )
        out = ds.compute_sigma_t_x(eigvecs, eigvals, x, t)
        assert torch.allclose(out, x, atol=1e-6)

    def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
        """compute_sigma_t_x preserves the input shape."""
        _, npz_paths, cfg_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=cfg_hash)
        x = torch.randn(2, 1024)  # (B, d) flattened
        t = torch.tensor([0.3, 0.7])
        eigvecs, eigvals = ds.get_gamma_b_sqrt([npz_paths[1], npz_paths[3]], device="cpu")
        out = ds.compute_sigma_t_x(eigvecs, eigvals, x, t)
        assert out.shape == x.shape

    def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
        """compute_sigma_t_x stays finite for random timesteps in [0, 1]."""
        _, npz_paths, cfg_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=cfg_hash)
        x = torch.randn(3, 1024)  # (B, d) flattened
        t = torch.rand(3)  # random timesteps in [0, 1]
        eigvecs, eigvals = ds.get_gamma_b_sqrt(
            [npz_paths[0], npz_paths[2], npz_paths[4]], device="cpu"
        )
        out = ds.compute_sigma_t_x(eigvecs, eigvals, x, t)
        assert not torch.isnan(out).any()
        assert torch.isfinite(out).all()
class TestCDCEndToEnd:
    """End-to-end CDC workflow: preprocess -> save -> load -> use."""

    def test_full_preprocessing_and_usage_workflow(self, tmp_path):
        """Run the whole pipeline on 10 small latents and sanity-check the output."""
        # Step 1: preprocess latents into per-file CDC caches.
        pre = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
        )
        num_samples = 10
        npz_paths = []
        for idx in range(num_samples):
            lat = torch.randn(16, 4, 4, dtype=torch.float32)
            path = str(tmp_path / f"test_image_{idx}_0004x0004_flux.npz")
            npz_paths.append(path)
            pre.add_latent(
                latent=lat,
                global_idx=idx,
                latents_npz_path=path,
                shape=lat.shape,
                metadata={'image_key': f'test_image_{idx}'},
            )
        assert pre.compute_all() == num_samples

        # Step 2: load the cache back through GammaBDataset.
        ds = GammaBDataset(device="cpu", config_hash=pre.config_hash)

        # Step 3: mimic a training step on a batch of three samples.
        batch_size = 3
        flat_latents = torch.randn(batch_size, 256)  # 16*4*4 = 256 flattened dims
        t_batch = torch.rand(batch_size)
        batch_paths = [npz_paths[0], npz_paths[5], npz_paths[9]]
        eigvecs, eigvals = ds.get_gamma_b_sqrt(batch_paths, device="cpu")
        out = ds.compute_sigma_t_x(eigvecs, eigvals, flat_latents, t_batch)
        # The geometry-aware noise must be well-formed.
        assert out.shape == flat_latents.shape
        assert not torch.isnan(out).any()
        assert torch.isfinite(out).all()
        # t=0 is the identity; t=1 must visibly transform the input.
        at_t0 = ds.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.zeros(batch_size))
        at_t1 = ds.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.ones(batch_size))
        assert torch.allclose(at_t0, flat_latents, atol=1e-6)
        assert not torch.allclose(at_t1, flat_latents, atol=0.1)
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell/CI; the original discarded the return code and
    # always exited 0.
    raise SystemExit(pytest.main([__file__, "-v"]))

View File

@@ -622,6 +622,29 @@ class NetworkTrainer:
accelerator.wait_for_everyone()
# CDC-FM preprocessing
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
cdc_output_path = os.path.join(args.output_dir, "cdc_gamma_b.safetensors")
self.cdc_cache_path = train_dataset_group.cache_cdc_gamma_b(
cdc_output_path=cdc_output_path,
k_neighbors=args.cdc_k_neighbors,
k_bandwidth=args.cdc_k_bandwidth,
d_cdc=args.cdc_d_cdc,
gamma=args.cdc_gamma,
force_recache=args.force_recache_cdc,
accelerator=accelerator,
debug=getattr(args, 'cdc_debug', False),
adaptive_k=getattr(args, 'cdc_adaptive_k', False),
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
)
if self.cdc_cache_path is None:
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
else:
self.cdc_cache_path = None
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
text_encoding_strategy = self.get_text_encoding_strategy(args)
@@ -660,6 +683,24 @@ class NetworkTrainer:
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# Load CDC-FM Γ_b dataset if enabled
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
from library.cdc_fm import GammaBDataset
# cdc_cache_path now contains the config hash
config_hash = self.cdc_cache_path if self.cdc_cache_path != "per_file" else None
if config_hash:
logger.info(f"CDC Γ_b dataset ready (hash: {config_hash})")
else:
logger.info("CDC Γ_b dataset ready (no hash, backward compatibility)")
self.gamma_b_dataset = GammaBDataset(
device="cuda" if torch.cuda.is_available() else "cpu",
config_hash=config_hash
)
else:
self.gamma_b_dataset = None
# prepare network
net_kwargs = {}
if args.network_args is not None: