This commit is contained in:
Dave Lage
2026-02-11 01:09:24 +01:00
committed by GitHub
14 changed files with 2856 additions and 30 deletions

View File

@@ -43,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 faiss-cpu==1.12.0
pip install -r requirements.txt
- name: Test with pytest

1
.gitignore vendored
View File

@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
benchmark_*.py

View File

@@ -1,7 +1,5 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union
import torch
@@ -36,6 +34,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
self.model_type: Optional[str] = None
self.gamma_b_dataset = None # CDC-FM Γ_b dataset
def assert_extra_args(
self,
@@ -327,9 +326,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
# Get CDC parameters if enabled
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None
latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None
# Get noisy model input and timesteps
# If CDC is enabled, this will transform the noise with geometry-aware covariance
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths
)
# pack latents and get img_ids
@@ -456,6 +461,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
# CDC-FM metadata
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
@@ -494,7 +508,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
logger.info("T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
@@ -533,6 +547,72 @@ def setup_parser() -> argparse.ArgumentParser:
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
# CDC-FM arguments
parser.add_argument(
"--use_cdc_fm",
action="store_true",
help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
" / CDC-FMCarré du Champ Flow Matchingを有効にして幾何学的イズを使用",
)
parser.add_argument(
"--cdc_k_neighbors",
type=int,
default=256,
help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
" / CDC-FMのk-NNグラフの近傍数デフォルト: 256",
)
parser.add_argument(
"--cdc_k_bandwidth",
type=int,
default=8,
help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
" / CDC-FMの帯域幅推定の近傍数デフォルト: 8",
)
parser.add_argument(
"--cdc_d_cdc",
type=int,
default=8,
help="Dimension of CDC subspace (default: 8)"
" / CDCサブ空間の次元デフォルト: 8",
)
parser.add_argument(
"--cdc_gamma",
type=float,
default=1.0,
help="CDC strength parameter (default: 1.0)"
" / CDC強度パラメータデフォルト: 1.0",
)
parser.add_argument(
"--force_recache_cdc",
action="store_true",
help="Force recompute CDC cache even if valid cache exists"
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
)
parser.add_argument(
"--cdc_debug",
action="store_true",
help="Enable verbose CDC debug output showing bucket details"
" / CDCの詳細デバッグ出力を有効化バケット詳細表示",
)
parser.add_argument(
"--cdc_adaptive_k",
action="store_true",
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
)
parser.add_argument(
"--cdc_min_bucket_size",
type=int,
default=16,
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスイズを使用。"
"--cdc_adaptive_k有効時のみ関連デフォルト: 16",
)
return parser

905
library/cdc_fm.py Normal file
View File

@@ -0,0 +1,905 @@
import logging
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from safetensors.torch import save_file
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class LatentSample:
    """
    Container for a single latent with metadata.

    Holds one flattened latent vector together with the bookkeeping needed to
    write its CDC cache file next to the original latent cache.
    """

    latent: np.ndarray  # (d,) flattened latent vector
    global_idx: int  # Global index in dataset
    shape: Tuple[int, ...]  # Original shape before flattening (C, H, W)
    latents_npz_path: str  # Path to the latent cache file
    metadata: Optional[Dict] = None  # Any extra info (prompt, filename, etc.)
class CarreDuChampComputer:
    """
    Core CDC-FM computation - agnostic to data source.

    Handles the math for a batch of same-size latents:
      1. build a k-NN graph over the flattened latents,
      2. estimate a per-point kernel bandwidth from the k_bw-th neighbor distance,
      3. per point, compute the top-d_cdc eigenpairs of the local
         Carré du Champ operator Γ_b via a weighted SVD of its neighborhood.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda'
    ):
        """
        Args:
            k_neighbors: Neighborhood size for the k-NN graph.
            k_bandwidth: Neighbor rank used for per-point bandwidth estimation.
            d_cdc: Number of eigenpairs kept per point (CDC subspace dimension).
            gamma: Global strength multiplier applied to rescaled eigenvalues.
            device: Preferred device; silently falls back to CPU when CUDA is absent.
        """
        self.k = k_neighbors
        self.k_bw = k_bandwidth
        self.d_cdc = d_cdc
        self.gamma = gamma
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build k-NN graph using pure PyTorch.

        Args:
            latents_np: (N, d) numpy array of same-dimensional latents
        Returns:
            distances: (N, k_actual+1) distances; column 0 is the point itself
            indices: (N, k_actual+1) neighbor indices (k_actual may be < k if N is small)
        """
        N, d = latents_np.shape
        # Clamp k to available neighbors (can't have more neighbors than samples)
        k_actual = min(self.k, N - 1)
        latents_tensor = torch.from_numpy(latents_np).to(self.device)
        # Chunk the query side so peak memory is a (chunk_size, N) matrix, not (N, N)
        chunk_size = 1000
        if N <= chunk_size:
            # Small batch: compute all pairwise distances at once
            distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2
            distances_k_sq, indices_k = torch.topk(
                distances_sq, k=k_actual + 1, dim=1, largest=False
            )
            distances = torch.sqrt(distances_k_sq).cpu().numpy()
            indices = indices_k.cpu().numpy()
        else:
            # Large batch: chunk queries to avoid OOM
            distances_list = []
            indices_list = []
            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                chunk = latents_tensor[i:end_i]
                distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2
                distances_k_sq, indices_k = torch.topk(
                    distances_sq, k=k_actual + 1, dim=1, largest=False
                )
                distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy())
                indices_list.append(indices_k.cpu().numpy())
                # Release the chunk's distance matrix before the next iteration
                del distances_sq, distances_k_sq, indices_k
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            distances = np.concatenate(distances_list, axis=0)
            indices = np.concatenate(indices_list, axis=0)
        return distances, indices

    @torch.no_grad()
    def compute_gamma_b_single(
        self,
        point_idx: int,
        latents_np: np.ndarray,
        distances: np.ndarray,
        indices: np.ndarray,
        epsilon: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Γ_b for a single point.

        Args:
            point_idx: Index of point to process
            latents_np: (N, d) all latents in this batch
            distances: (N, k+1) precomputed distances (column 0 is self)
            indices: (N, k+1) precomputed neighbor indices
            epsilon: (N,) bandwidth per point
        Returns:
            eigenvectors: (d_cdc, d) fp16 tensor (rows are right singular vectors)
            eigenvalues: (d_cdc,) fp16 tensor, rescaled into [1e-3, max(gamma, 1e-3)]
        """
        d = latents_np.shape[1]
        # Neighborhood excludes the point itself (column 0 of the k-NN result)
        neighbor_idx = indices[point_idx, 1:]  # (k,)
        neighbor_points = latents_np[neighbor_idx]  # (k, d)
        # Clamp distances to prevent overflow (max realistic L2 distance)
        MAX_DISTANCE = 1e10
        neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
        neighbor_dists_sq = neighbor_dists ** 2  # (k,)
        # Gaussian kernel weights with numerical guards against zero bandwidths
        eps_i = max(epsilon[point_idx], 1e-10)  # Prevent division by zero
        eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)
        denom = np.maximum(eps_i * eps_neighbors, 1e-20)  # Additional guard
        exp_arg = np.clip(-neighbor_dists_sq / denom, -50, 0)  # Prevent exp overflow/underflow
        weights = np.exp(exp_arg)
        # Normalize weights; fall back to uniform if they all underflowed
        weight_sum = weights.sum()
        if weight_sum < 1e-20 or not np.isfinite(weight_sum):
            weights = np.ones_like(weights) / len(weights)
        else:
            weights = weights / weight_sum
        # Weighted local mean, then center + sqrt-weight so that
        # weighted_centered^T @ weighted_centered equals the weighted covariance
        m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
        centered = neighbor_points - m_star
        weighted_centered = np.sqrt(weights)[:, None] * centered  # (k, d)
        # Validate input is finite before SVD
        if not np.all(np.isfinite(weighted_centered)):
            logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
            # Fallback: use uniform weights and simple centering
            weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
            m_star = np.mean(neighbor_points, axis=0)
            centered = neighbor_points - m_star
            weighted_centered = np.sqrt(weights_uniform)[:, None] * centered
        weighted_centered_torch = torch.from_numpy(weighted_centered).to(
            self.device, dtype=torch.float32
        )
        try:
            U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
        except RuntimeError as e:
            # GPU SVD can fail on ill-conditioned input; retry on CPU via numpy
            logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
            try:
                U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False)
                U = torch.from_numpy(U).to(self.device)
                S = torch.from_numpy(S).to(self.device)
                Vh = torch.from_numpy(Vh).to(self.device)
            except np.linalg.LinAlgError as e2:
                logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
                # Zero eigenpairs downstream mean "no CDC correction" (identity)
                return (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16)
                )
        # Eigenvalues of Γ_b are the squared singular values
        eigenvalues_full = S ** 2
        if len(eigenvalues_full) >= self.d_cdc:
            # Keep top d_cdc eigenpairs
            top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
            top_eigenvectors = Vh[top_idx, :]  # (d_cdc, d)
        else:
            # Fewer than d_cdc components available (k < d_cdc): zero-pad
            pad_size = self.d_cdc - len(eigenvalues_full)
            top_eigenvalues = torch.cat([
                eigenvalues_full,
                torch.zeros(pad_size, device=self.device)
            ])
            top_eigenvectors = torch.cat([
                Vh,
                torch.zeros(pad_size, d, device=self.device)
            ])
        # Eigenvalue rescaling (per CDC-FM paper Appendix E, Equation 33):
        # 1. normalize by the leading eigenvalue λ_1 (paper's 1/λ_1 factor),
        # 2. apply the global γ scaling,
        # 3. clamp for numerical stability.
        # Raw SVD eigenvalues can be very large for high-dimensional latents;
        # without the normalization, clamping would saturate at the upper bound.
        max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0
        if max_eigenval > 1e-10:
            # Scale so max eigenvalue = 1.0, preserving relative ratios
            top_eigenvalues = top_eigenvalues / max_eigenval
        # BUGFIX: keep the clamp bounds ordered even when gamma < 1e-3 —
        # torch.clamp with min > max silently pins every value to max.
        upper_bound = max(self.gamma * 1.0, 1e-3)
        top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, upper_bound)
        # fp16 storage is safe: values now lie in [1e-3, max(gamma, 1e-3)],
        # well inside the fp16 range (6e-5 .. 65504)
        eigenvectors_fp16 = top_eigenvectors.cpu().half()
        eigenvalues_fp16 = top_eigenvalues.cpu().half()
        # Release GPU intermediates promptly
        del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return eigenvectors_fp16, eigenvalues_fp16

    def compute_for_batch(
        self,
        latents_np: np.ndarray,
        global_indices: List[int]
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute Γ_b for all points in a batch of same-size latents.

        Args:
            latents_np: (N, d) numpy array
            global_indices: List of global dataset indices for each latent
        Returns:
            Dict mapping global_idx -> (eigenvectors, eigenvalues), both fp16
            torch tensors. Batches too small for meaningful geometry estimation
            get zero eigenpairs (interpreted downstream as "no CDC correction").
        Raises:
            ValueError: If global_indices length does not match latents_np rows.
        """
        N, d = latents_np.shape
        # Validate inputs
        if len(global_indices) != N:
            raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")
        # NOTE(review): print rather than logger — consider unifying with the module logger
        print(f"Computing CDC for batch: {N} samples, dim={d}")
        # Require a minimum number of samples for a meaningful k-NN geometry estimate
        MIN_SAMPLES_FOR_CDC = 5
        if N < MIN_SAMPLES_FOR_CDC:
            print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
            results = {}
            for local_idx in range(N):
                global_idx = global_indices[local_idx]
                # CONSISTENCY FIX: return fp16 torch tensors (as the SVD path
                # does) instead of numpy arrays, so downstream handling is uniform.
                eigvecs = torch.zeros(self.d_cdc, d, dtype=torch.float16)
                eigvals = torch.zeros(self.d_cdc, dtype=torch.float16)
                results[global_idx] = (eigvecs, eigvals)
            return results
        # Step 1: Build k-NN graph
        print(" Building k-NN graph...")
        distances, indices = self.compute_knn_graph(latents_np)
        # Step 2: Per-point bandwidth = distance to the k_bw-th neighbor
        # (min() guards against fewer neighbors than k_bw being available)
        k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
        epsilon = distances[:, k_bw_actual]
        # Step 3: Compute Γ_b for each point
        results = {}
        print(" Computing Γ_b for each point...")
        for local_idx in tqdm(range(N), desc=" Processing", leave=False):
            global_idx = global_indices[local_idx]
            eigvecs, eigvals = self.compute_gamma_b_single(
                local_idx, latents_np, distances, indices, epsilon
            )
            results[global_idx] = (eigvecs, eigvals)
        return results
class LatentBatcher:
    """
    Accumulates variable-size latent samples and groups them by latent shape.

    Samples are queued via add_latent/add_sample and later partitioned into
    exact-shape buckets by get_batches, so each bucket can be processed as a
    single fixed-dimension batch.
    """

    def __init__(self, size_tolerance: float = 0.0):
        """
        Args:
            size_tolerance: If > 0, group latents within tolerance % of size
                If 0, only exact size matches are batched
        """
        self.size_tolerance = size_tolerance
        self.samples: List[LatentSample] = []

    def add_sample(self, sample: LatentSample):
        """Queue one pre-built LatentSample."""
        self.samples.append(sample)

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        latents_npz_path: str,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Queue a latent vector, flattening it and recording its original shape.

        Args:
            latent: Latent vector (any shape, will be flattened)
            global_idx: Global index in dataset
            latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz")
            shape: Original shape (if None, uses latent.shape)
            metadata: Optional metadata dict
        """
        # Normalize to numpy; capture the shape before flattening
        arr = latent.cpu().numpy() if isinstance(latent, torch.Tensor) else latent
        recorded_shape = arr.shape if shape is None else shape
        self.add_sample(
            LatentSample(
                latent=arr.flatten(),
                global_idx=global_idx,
                shape=recorded_shape,
                latents_npz_path=latents_npz_path,
                metadata=metadata,
            )
        )

    def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
        """
        Partition queued samples into buckets keyed by exact latent shape.

        No resizing or aspect-ratio merging happens here: a bucket holds only
        samples whose dimensions match exactly. Downstream CDC computation
        skips buckets that are too small and falls back to Gaussian noise.

        Returns:
            Dict mapping exact_shape -> list of samples with that shape
        """
        grouped: Dict[Tuple[int, ...], List[LatentSample]] = {}
        for item in self.samples:
            grouped.setdefault(item.shape, []).append(item)
        # More than one distinct shape means mixed-resolution data was queued
        if len(grouped) > 1:
            logger.warning(
                "Dimension mismatch: %d unique shapes detected. "
                "Shapes: %s. Using Gaussian fallback for these samples.",
                len(grouped),
                set(grouped)
            )
        return grouped

    def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
        """
        Map a (C, H, W) shape to a coarse aspect-ratio label (e.g. "1:1", "16:9").

        Shapes with fewer than 3 dims are labelled "unknown"; ratios outside
        the binned range fall back to "ultra_tall" / "ultra_wide".
        """
        if len(shape) < 3:
            return "unknown"
        height, width = shape[-2], shape[-1]
        ratio = height / width
        # Every bin starts at 0.5, so anything narrower is ultra-tall by definition
        if ratio < 0.5:
            return "ultra_tall"
        for lower, upper, label in (
            (0.5, 0.65, "9:16"),   # Portrait tall
            (0.65, 0.85, "3:4"),   # Portrait
            (0.85, 1.15, "1:1"),   # Square
            (1.15, 1.50, "4:3"),   # Landscape
            (1.50, 2.0, "16:9"),   # Landscape wide
            (2.0, 3.0, "21:9"),    # Ultra wide
        ):
            if lower <= ratio < upper:
                return label
        # Remaining ratios are >= 3.0
        return "ultra_wide"

    def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
        """Return whether two shapes have total element counts within size_tolerance."""
        if len(shape1) != len(shape2):
            return False
        size_a = np.prod(shape1)
        size_b = np.prod(shape2)
        return abs(size_a - size_b) / max(size_a, size_b) <= self.size_tolerance

    def __len__(self):
        """Number of queued samples."""
        return len(self.samples)
class CDCPreprocessor:
    """
    High-level CDC preprocessing coordinator.

    Handles variable-size latents by batching (via LatentBatcher) and delegating
    the per-bucket math to CarreDuChampComputer, then saves one CDC cache file
    next to each latent cache file.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda',
        size_tolerance: float = 0.0,
        debug: bool = False,
        adaptive_k: bool = False,
        min_bucket_size: int = 16,
        dataset_dirs: Optional[List[str]] = None
    ):
        # Size-agnostic Γ_b computer; does the actual k-NN + SVD work
        self.computer = CarreDuChampComputer(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=d_cdc,
            gamma=gamma,
            device=device
        )
        # Groups queued latents into exact-shape buckets
        self.batcher = LatentBatcher(size_tolerance=size_tolerance)
        self.debug = debug
        self.adaptive_k = adaptive_k
        self.min_bucket_size = min_bucket_size
        self.dataset_dirs = dataset_dirs or []
        # Hash of dataset dirs + CDC params; namespaces the cache filenames
        self.config_hash = self._compute_config_hash()

    def _compute_config_hash(self) -> str:
        """
        Compute a short hash of CDC configuration for filename uniqueness.
        Hash includes:
        - Sorted dataset/subset directory paths
        - CDC parameters (k_neighbors, d_cdc, gamma)
        This ensures CDC files are invalidated when:
        - Dataset composition changes (different dirs)
        - CDC parameters change
        NOTE(review): k_bandwidth, adaptive_k and min_bucket_size are NOT part
        of the hash, so changing them will not invalidate existing caches —
        confirm this is intended.
        Returns:
            8-character hex hash
        """
        import hashlib
        # Sort dataset dirs for consistent hashing
        dirs_str = "|".join(sorted(self.dataset_dirs))
        # Include CDC parameters
        config_str = f"{dirs_str}|k={self.computer.k}|d={self.computer.d_cdc}|gamma={self.computer.gamma}"
        # Create short hash (8 chars is enough for uniqueness in this context)
        hash_obj = hashlib.sha256(config_str.encode())
        return hash_obj.hexdigest()[:8]

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        latents_npz_path: str,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a single latent to the preprocessing queue.

        Args:
            latent: Latent vector (will be flattened)
            global_idx: Global dataset index
            latents_npz_path: Path to the latent cache file
            shape: Original shape (C, H, W)
            metadata: Optional metadata
        """
        self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)

    @staticmethod
    def get_cdc_npz_path(
        latents_npz_path: str,
        config_hash: Optional[str] = None,
        latent_shape: Optional[Tuple[int, ...]] = None
    ) -> str:
        """
        Get CDC cache path from latents cache path.

        Includes optional config_hash to ensure CDC files are unique to dataset/subset
        configuration and CDC parameters. This prevents using stale CDC files when
        the dataset composition or CDC settings change.

        IMPORTANT: When using multi-resolution training, you MUST pass latent_shape to ensure
        CDC files are unique per resolution. Without it, different resolutions will overwrite
        each other's CDC caches, causing dimension mismatch errors.

        Args:
            latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
            config_hash: Optional 8-char hash of (dataset_dirs + CDC params)
                If None, returns path without hash (for backward compatibility)
            latent_shape: Optional latent shape tuple (C, H, W) to make CDC resolution-specific
                For multi-resolution training, this MUST be provided
        Returns:
            CDC cache path examples:
            - With shape + hash: "image_0512x0768_flux_cdc_104x80_a1b2c3d4.npz"
            - With hash only: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
            - Without hash: "image_0512x0768_flux_cdc.npz"
        Raises:
            ValueError: If latent_shape is given but has fewer than 3 dimensions.
        """
        path = Path(latents_npz_path)
        # Build filename components
        components = [path.stem, "cdc"]
        # Add latent resolution if provided (for multi-resolution training)
        if latent_shape is not None:
            if len(latent_shape) >= 3:
                # Format: HxW (e.g., "104x80" from shape (16, 104, 80))
                h, w = latent_shape[-2], latent_shape[-1]
                components.append(f"{h}x{w}")
            else:
                raise ValueError(f"latent_shape must have at least 3 dimensions (C, H, W), got {latent_shape}")
        # Add config hash if provided
        if config_hash:
            components.append(config_hash)
        # Build final filename (Path.with_stem requires Python 3.9+)
        new_stem = "_".join(components)
        return str(path.with_stem(new_stem))

    def compute_all(self) -> int:
        """
        Compute Γ_b for all added latents and save individual CDC files next to each latent cache.

        Returns:
            Number of CDC files saved
        """
        # Get batches by exact size (no resizing)
        batches = self.batcher.get_batches()
        # Count samples that will get CDC vs fallback.
        # Effective skip threshold: adaptive mode can go below k_neighbors.
        k_neighbors = self.computer.k
        min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
        if self.adaptive_k:
            samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
        else:
            samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
        samples_fallback = len(self.batcher) - samples_with_cdc
        if self.debug:
            # NOTE(review): these percentages divide by len(self.batcher) —
            # calling compute_all() with no samples queued raises ZeroDivisionError here.
            print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
            if self.adaptive_k:
                print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
            print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
            print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
        else:
            mode = "adaptive" if self.adaptive_k else "fixed"
            logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
        # Storage for results: global_idx -> (eigenvectors, eigenvalues)
        all_results = {}
        # Process each bucket with progress bar (suppressed in debug mode, which prints instead)
        bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug) if not self.debug else batches.items()
        for shape, samples in bucket_iter:
            num_samples = len(samples)
            if self.debug:
                print(f"\n{'='*60}")
                print(f"Bucket: {shape} ({num_samples} samples)")
                print(f"{'='*60}")
            # Determine effective k for this bucket
            if self.adaptive_k:
                # Adaptive mode: skip if below minimum, otherwise use best available k
                if num_samples < min_threshold:
                    if self.debug:
                        print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
                        print(" → These samples will use standard Gaussian noise (no CDC)")
                    # Store zero eigenvectors/eigenvalues (Gaussian fallback)
                    C, H, W = shape
                    d = C * H * W
                    for sample in samples:
                        eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
                        eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
                        all_results[sample.global_idx] = (eigvecs, eigvals)
                    continue
                # Use adaptive k for this bucket (capped by bucket population)
                k_effective = min(k_neighbors, num_samples - 1)
            else:
                # Fixed mode: skip if below k_neighbors
                if num_samples < k_neighbors:
                    if self.debug:
                        print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
                        print(" → These samples will use standard Gaussian noise (no CDC)")
                    # Store zero eigenvectors/eigenvalues (Gaussian fallback)
                    C, H, W = shape
                    d = C * H * W
                    for sample in samples:
                        eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
                        eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
                        all_results[sample.global_idx] = (eigvecs, eigvals)
                    continue
                k_effective = k_neighbors
            # Collect latents (no resizing needed - all same shape)
            latents_list = []
            global_indices = []
            for sample in samples:
                global_indices.append(sample.global_idx)
                latents_list.append(sample.latent)  # Already flattened
            latents_np = np.stack(latents_list, axis=0)  # (N, C*H*W)
            # Compute CDC for this batch with effective k
            if self.debug:
                if self.adaptive_k and k_effective < k_neighbors:
                    print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
                else:
                    print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
            # Temporarily override k for this bucket, then restore it
            original_k = self.computer.k
            self.computer.k = k_effective
            batch_results = self.computer.compute_for_batch(latents_np, global_indices)
            self.computer.k = original_k
            # No resizing needed - eigenvectors are already correct size
            if self.debug:
                print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
            # Merge into overall results
            all_results.update(batch_results)
        # Save individual CDC files next to each latent cache
        if self.debug:
            print(f"\n{'='*60}")
            print("Saving individual CDC files...")
            print(f"{'='*60}")
        files_saved = 0
        total_size = 0
        save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples
        for sample in save_iter:
            # Get CDC cache path with config hash and latent shape (for multi-resolution support)
            cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash, sample.shape)
            # Get CDC results for this sample (every sample should have an entry:
            # either real eigenpairs or the zero-valued Gaussian fallback)
            if sample.global_idx in all_results:
                eigvecs, eigvals = all_results[sample.global_idx]
                # Convert to numpy if needed (compute path returns torch tensors,
                # the fallback path stores numpy arrays)
                if isinstance(eigvecs, torch.Tensor):
                    eigvecs = eigvecs.numpy()
                if isinstance(eigvals, torch.Tensor):
                    eigvals = eigvals.numpy()
                # Save metadata and CDC results
                np.savez(
                    cdc_path,
                    eigenvectors=eigvecs,
                    eigenvalues=eigvals,
                    shape=np.array(sample.shape),
                    k_neighbors=self.computer.k,
                    d_cdc=self.computer.d_cdc,
                    gamma=self.computer.gamma
                )
                files_saved += 1
                total_size += Path(cdc_path).stat().st_size
                logger.debug(f"Saved CDC file: {cdc_path}")
        total_size_mb = total_size / 1024 / 1024
        logger.info(f"Saved {files_saved} CDC files, total size: {total_size_mb:.2f} MB")
        return files_saved
class GammaBDataset:
    """
    Efficient loader for Γ_b matrices during training.
    Loads from individual CDC cache files next to latent caches.
    """

    def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None):
        """
        Initialize CDC dataset loader.

        Args:
            device: Device for loading tensors (falls back to CPU if CUDA is unavailable)
            config_hash: Optional config hash to use for CDC file lookup.
                If None, uses default naming without hash.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.config_hash = config_hash
        if config_hash:
            logger.info(f"CDC loader initialized (hash: {config_hash})")
        else:
            logger.info("CDC loader initialized (no hash, backward compatibility mode)")

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        latents_npz_paths: List[str],
        device: Optional[Union[str, torch.device]] = None,
        latent_shape: Optional[Tuple[int, ...]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get Γ_b^(1/2) components for a batch of latents.

        Args:
            latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
            device: Device to load to (defaults to self.device, which is a torch.device)
            latent_shape: Latent shape (C, H, W) to identify which CDC file to load.
                Required for multi-resolution training to avoid loading wrong CDC.
        Returns:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
        Raises:
            FileNotFoundError: If a CDC cache file is missing for any sample.
            RuntimeError: If the batch mixes CDC files with different latent dimensions.
        Note:
            For multi-resolution training, latent_shape MUST be provided to load the correct
            CDC file. Without it, the wrong CDC file may be loaded, causing dimension mismatch.
        """
        if device is None:
            device = self.device
        eigenvectors_list = []
        eigenvalues_list = []
        # NOTE(review): one np.load per sample per call — there is no in-memory
        # caching here, so every training step re-reads these files from disk.
        for latents_npz_path in latents_npz_paths:
            # Get CDC cache path with config hash and latent shape (for multi-resolution support)
            cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash, latent_shape)
            # Load CDC data
            if not Path(cdc_path).exists():
                raise FileNotFoundError(
                    f"CDC cache file not found: {cdc_path}. "
                    f"Make sure to run CDC preprocessing before training."
                )
            data = np.load(cdc_path)
            # Stored as fp16; upcast to fp32 for numerically stable use in training
            eigvecs = torch.from_numpy(data['eigenvectors']).to(device).float()
            eigvals = torch.from_numpy(data['eigenvalues']).to(device).float()
            eigenvectors_list.append(eigvecs)
            eigenvalues_list.append(eigvals)
        # Stack - all should have same d_cdc and d within a batch (enforced by bucketing).
        # Verify eigenvector dimensions agree before stacking.
        dims = [ev.shape[1] for ev in eigenvectors_list]
        if len(set(dims)) > 1:
            # Dimension mismatch! This shouldn't happen with proper bucketing
            # but can occur if batch contains mixed sizes
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
                f"Latent paths: {latents_npz_paths}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )
        eigenvectors = torch.stack(eigenvectors_list, dim=0)
        eigenvalues = torch.stack(eigenvalues_list, dim=0)
        return eigenvectors, eigenvalues

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2).

        The Γ_b^(1/2) term acts only within the cached top-d_cdc subspace;
        components of x orthogonal to that subspace receive only the (1-t)
        identity contribution.

        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d) or (B, C, H, W) - will be flattened if needed
            t: (B,) or scalar time
        Returns:
            result: Same shape as input x
        Note:
            Gradients flow through this function for backprop during training
            (deliberately no @torch.no_grad() here, unlike get_gamma_b_sqrt).
        """
        # Store original shape to restore later
        orig_shape = x.shape
        # Flatten x if it's 4D
        if x.dim() == 4:
            B, C, H, W = x.shape
            x = x.reshape(B, -1)  # (B, C*H*W)
        # Normalize t to a (B, 1) tensor for broadcasting
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)
        if t.dim() == 0:
            t = t.expand(x.shape[0])
        t = t.view(-1, 1)
        # Early return for t=0 to avoid numerical errors
        # (skipped when t carries gradients, to keep the graph intact)
        if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)
        # Check if CDC is disabled (all eigenvalues are zero)
        # This happens for buckets with < k_neighbors samples
        if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8):
            # Fallback to standard Gaussian noise (no CDC correction)
            return x.reshape(orig_shape)
        # Γ_b^(1/2) @ x using the low-rank representation:
        # project onto eigenvectors, scale by sqrt(eigenvalues), project back
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
        # Σ_t @ x: interpolate between identity and Γ_b^(1/2)
        result = (1 - t) * x + t * gamma_sqrt_x
        # Restore original shape
        result = result.reshape(orig_shape)
        return result

View File

@@ -2,10 +2,8 @@ import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
from accelerate import Accelerator, PartialState
@@ -183,7 +181,7 @@ def sample_image_inference(
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info("negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
@@ -468,9 +466,91 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Global set to track samples that have already been warned about shape mismatches
# This prevents log spam during training (warning once per sample is sufficient)
# NOTE: module-level mutable state — shared across all calls in this process.
_cdc_warned_samples = set()
def apply_cdc_noise_transformation(
    noise: torch.Tensor,
    timesteps: torch.Tensor,
    num_timesteps: int,
    gamma_b_dataset,
    latents_npz_paths,
    device
) -> torch.Tensor:
    """
    Apply the CDC-FM geometry-aware covariance to a batch of Gaussian noise.

    Args:
        noise: (B, C, H, W) standard Gaussian noise
        timesteps: (B,) timesteps for this batch
        num_timesteps: Total number of timesteps in the scheduler
        gamma_b_dataset: GammaBDataset holding the cached CDC matrices
        latents_npz_paths: Latent cache paths identifying each batch sample
        device: Device the CDC matrices should be loaded onto

    Returns:
        Noise tensor of the same shape with Σ_t applied per sample.
    """
    # Normalize the requested device ("cuda" vs "cuda:0" should compare equal in kind).
    requested = device if isinstance(device, torch.device) else torch.device(device)
    current = noise.device
    same_kind = (
        current == requested
        or (current.type == "cuda" and requested.type == "cuda")
        or (current.type == "cpu" and requested.type == "cpu")
    )
    if not same_kind:
        logger.warning(
            f"CDC device mismatch: noise on {current} but CDC loading to {requested}. "
            f"Transferring noise to {requested} to avoid errors."
        )
        noise = noise.to(requested)
        device = requested

    # CDC-FM works with time normalized to [0, 1].
    t_norm = timesteps.to(device) / num_timesteps

    batch, channels, height, width = noise.shape
    # Fetch the low-rank factors for every sample in one call; the latent
    # shape selects the right cache entry for multi-resolution training.
    eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(
        latents_npz_paths, device=device, latent_shape=(channels, height, width)
    )

    flat = noise.reshape(batch, -1)
    transformed = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, flat, t_norm)
    return transformed.reshape(batch, channels, height, width)
def get_noisy_model_input_and_timestep(
args,
noise_scheduler,
latents: torch.Tensor,
noise: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
gamma_b_dataset=None,
latents_npz_paths=None,
timestep_index: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate noisy model input and corresponding timesteps for training.
Args:
args: Configuration with sampling parameters
noise_scheduler: Scheduler for noise/timestep management
latents: Clean latent representations
noise: Random noise tensor
device: Target device
dtype: Target dtype
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided)
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
@@ -514,10 +594,20 @@ def get_noisy_model_input_and_timesteps(
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
# Apply CDC-FM geometry-aware noise transformation if enabled
if gamma_b_dataset is not None and latents_npz_paths is not None:
noise = apply_cdc_noise_transformation(
noise=noise,
timesteps=timesteps,
num_timesteps=num_timesteps,
gamma_b_dataset=gamma_b_dataset,
latents_npz_paths=latents_npz_paths,
device=device
)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:

View File

@@ -40,6 +40,8 @@ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from library.cdc_fm import CDCPreprocessor
from diffusers.optimization import (
SchedulerType as DiffusersSchedulerType,
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
@@ -1572,11 +1574,17 @@ class BaseDataset(torch.utils.data.Dataset):
flippeds = [] # 変数名が微妙
text_encoder_outputs_list = []
custom_attributes = []
image_keys = [] # CDC-FM: track image keys for CDC lookup
latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
# CDC-FM: Store image_key and latents_npz path for CDC lookup
image_keys.append(image_key)
latents_npz_paths.append(image_info.latents_npz)
custom_attributes.append(subset.custom_attributes)
# in case of fine tuning, is_reg is always False
@@ -1818,6 +1826,9 @@ class BaseDataset(torch.utils.data.Dataset):
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
# CDC-FM: Add latents_npz paths to batch for CDC lookup
example["latents_npz"] = latents_npz_paths
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -2647,6 +2658,220 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def cache_cdc_gamma_b(
    self,
    k_neighbors: int = 256,
    k_bandwidth: int = 8,
    d_cdc: int = 8,
    gamma: float = 1.0,
    force_recache: bool = False,
    accelerator: Optional["Accelerator"] = None,
    debug: bool = False,
    adaptive_k: bool = False,
    min_bucket_size: int = 16,
) -> Optional[str]:
    """
    Cache CDC Γ_b matrices for all latents in the dataset.

    CDC files are saved as individual .npz files next to each latent cache file.
    For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc_a1b2c3d4.npz
    where 'a1b2c3d4' is the config hash (dataset dirs + CDC params).

    Only the main process computes; other ranks wait at the barrier and return
    the config hash so every rank can build its GammaBDataset consistently.

    Args:
        k_neighbors: k-NN neighbors
        k_bandwidth: Bandwidth estimation neighbors
        d_cdc: CDC subspace dimension
        gamma: CDC strength
        force_recache: Force recompute even if cache exists
        accelerator: For multi-GPU support
        debug: Enable debug logging
        adaptive_k: Enable adaptive k selection for small buckets
        min_bucket_size: Minimum bucket size for CDC computation

    Returns:
        Config hash string for this CDC configuration, or None on error
    """
    from pathlib import Path

    # Validate that latent caching is enabled
    # CDC requires latents to be cached (either to disk or in memory) because:
    # 1. CDC files are named based on latent cache filenames
    # 2. CDC files are saved next to latent cache files
    # 3. Training needs latent paths to load corresponding CDC files
    has_cached_latents = False
    for dataset in self.datasets:
        for info in dataset.image_data.values():
            if info.latents is not None or info.latents_npz is not None:
                has_cached_latents = True
                break
        if has_cached_latents:
            break

    if not has_cached_latents:
        raise ValueError(
            "CDC-FM requires latent caching to be enabled. "
            "Please enable latent caching by setting one of:\n"
            " - cache_latents = true (cache in memory)\n"
            " - cache_latents_to_disk = true (cache to disk)\n"
            "in your training config or command line arguments."
        )

    # Collect dataset/subset directories for config hash
    dataset_dirs = []
    for dataset in self.datasets:
        # Get the directory containing the images
        if hasattr(dataset, 'image_dir'):
            dataset_dirs.append(str(dataset.image_dir))
        # Fallback: use first image's parent directory
        elif dataset.image_data:
            first_image = next(iter(dataset.image_data.values()))
            dataset_dirs.append(str(Path(first_image.absolute_path).parent))

    # Create preprocessor to get config hash
    preprocessor = CDCPreprocessor(
        k_neighbors=k_neighbors,
        k_bandwidth=k_bandwidth,
        d_cdc=d_cdc,
        gamma=gamma,
        device="cuda" if torch.cuda.is_available() else "cpu",
        debug=debug,
        adaptive_k=adaptive_k,
        min_bucket_size=min_bucket_size,
        dataset_dirs=dataset_dirs
    )

    logger.info(f"CDC config hash: {preprocessor.config_hash}")

    # Check if CDC caches already exist (unless force_recache)
    if not force_recache:
        all_cached = self._check_cdc_caches_exist(preprocessor.config_hash)
        if all_cached:
            logger.info("All CDC cache files found, skipping preprocessing")
            return preprocessor.config_hash
        else:
            logger.info("Some CDC cache files missing, will compute")

    # Only main process computes CDC; other ranks park on the barrier below
    # and return the (deterministic) config hash immediately.
    is_main = accelerator is None or accelerator.is_main_process
    if not is_main:
        if accelerator is not None:
            accelerator.wait_for_everyone()
        return preprocessor.config_hash

    logger.info("Starting CDC-FM preprocessing")
    logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")

    # Get caching strategy for loading latents
    from library.strategy_base import LatentsCachingStrategy
    caching_strategy = LatentsCachingStrategy.get_strategy()

    # Collect all latents from all datasets
    for dataset_idx, dataset in enumerate(self.datasets):
        logger.info(f"Loading latents from dataset {dataset_idx}...")
        image_infos = list(dataset.image_data.values())

        for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset {dataset_idx}")):
            # Load latent from disk or memory
            if info.latents is not None:
                latent = info.latents
            elif info.latents_npz is not None:
                # Load from disk
                latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso)
                if latent is None:
                    logger.warning(f"Failed to load latent from {info.latents_npz}, skipping")
                    continue
            else:
                logger.warning(f"No latent found for {info.absolute_path}, skipping")
                continue

            # Add to preprocessor (with unique global index across all datasets)
            # NOTE(review): the prefix sum is recomputed per image — O(n_datasets)
            # per sample; fine for typical dataset counts.
            actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx

            # Get latents_npz_path - will be set whether caching to disk or memory
            if info.latents_npz is None:
                # If not set, generate the path from the caching strategy
                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso)

            preprocessor.add_latent(
                latent=latent,
                global_idx=actual_global_idx,
                latents_npz_path=info.latents_npz,
                shape=latent.shape,
                metadata={"image_key": info.image_key}
            )

    # Compute and save individual CDC files
    logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
    files_saved = preprocessor.compute_all()
    logger.info(f"Saved {files_saved} CDC cache files")

    if accelerator is not None:
        accelerator.wait_for_everyone()

    # Return config hash so training can initialize GammaBDataset with it
    return preprocessor.config_hash
def _check_cdc_caches_exist(self, config_hash: str) -> bool:
    """
    Check if CDC cache files exist for all latents in the dataset.

    Args:
        config_hash: The config hash to use for CDC filename lookup

    Returns:
        True when every checkable latent has its CDC cache file on disk.
        Entries without a ``latents_npz`` path cannot be checked and are
        skipped (they do not count as missing).
    """
    from pathlib import Path
    # Hoisted out of the per-image loop: the original re-executed
    # ``import numpy as np`` for every file, which is needless work and
    # hides the dependency.
    import numpy as np

    missing_count = 0
    total_count = 0

    for dataset in self.datasets:
        for info in dataset.image_data.values():
            total_count += 1

            if info.latents_npz is None:
                # If latents_npz not set, we can't check for CDC cache
                continue

            # Compute expected latent shape from bucket_reso
            # For multi-resolution CDC, we need to pass latent_shape to get the correct filename
            latent_shape = None
            if info.bucket_reso is not None:
                # Get latent shape efficiently without loading full data
                # First check if latent is already in memory
                if info.latents is not None:
                    latent_shape = info.latents.shape
                else:
                    # Load latent shape from npz file metadata
                    # This is faster than loading the full latent data
                    try:
                        with np.load(info.latents_npz) as data:
                            # Find the key for this bucket resolution
                            # Multi-resolution format uses keys like "latents_104x80"
                            h, w = info.bucket_reso[1] // 8, info.bucket_reso[0] // 8
                            key = f"latents_{h}x{w}"
                            if key in data:
                                latent_shape = data[key].shape
                            elif 'latents' in data:
                                # Fallback for single-resolution cache
                                latent_shape = data['latents'].shape
                    except Exception as e:
                        logger.debug(f"Failed to read latent shape from {info.latents_npz}: {e}")
                        # Fall back to checking without shape (backward compatibility)
                        latent_shape = None

            cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash, latent_shape)
            if not Path(cdc_path).exists():
                missing_count += 1
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Missing CDC cache: {cdc_path}")

    if missing_count > 0:
        logger.info(f"Found {missing_count}/{total_count} missing CDC cache files")
        return False

    logger.debug(f"All {total_count} CDC cache files exist")
    return True
def set_caching_mode(self, caching_mode):
    """Propagate the given caching mode to every child dataset."""
    for child_dataset in self.datasets:
        child_dataset.set_caching_mode(caching_mode)
@@ -6064,8 +6289,19 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
def get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents: torch.FloatTensor
args, noise_scheduler, latents: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
"""
Sample noise and create noisy latents.
Args:
args: Training arguments
noise_scheduler: The noise scheduler
latents: Clean latents
Returns:
(noise, noisy_latents, timesteps)
"""
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:

View File

@@ -0,0 +1,183 @@
import torch
from typing import Union
class MockGammaBDataset:
    """
    Lightweight stand-in for GammaBDataset that exposes only
    ``compute_sigma_t_x``, so gradient-flow tests can run without any
    cached CDC files on disk.
    """

    def __init__(self, *args, **kwargs):
        """Accept (and ignore) any constructor arguments; pick a default device."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Apply Σ_t = (1-t) I + t Γ^(1/2) to ``x`` via the low-rank factors.

        Mirrors the real implementation closely enough to exercise gradient
        flow through both ``x`` and ``t``.
        """
        shape_in = x.shape
        if x.dim() == 4:
            x = x.reshape(shape_in[0], -1)  # flatten (B, C, H, W) -> (B, C*H*W)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        # Sanity-check the low-rank factors against the input.
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"

        # At t == 0 with no gradient attached to t, Σ_0 is the identity.
        if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(shape_in)

        # Γ^(1/2) x = V^T diag(sqrt(λ)) V x, computed in the d_cdc subspace.
        coords = torch.einsum('bkd,bd->bk', eigenvectors, x)
        weighted = torch.sqrt(eigenvalues.clamp(min=1e-10)) * coords
        low_rank = torch.einsum('bkd,bk->bd', eigenvectors, weighted)

        # Blend identity and Γ^(1/2) contributions, restoring the input shape.
        return ((1 - t) * x + t * low_rank).reshape(shape_in)
class TestCDCAdvanced:
    # NOTE: the statement order in these tests is significant — after
    # torch.manual_seed(42) each randn/rand call consumes RNG state in
    # sequence, so reordering tensor creation changes the sampled values.

    def setup_method(self):
        """Prepare consistent test environment"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_gradient_flow_preservation(self):
        """
        Verify that gradient flow is preserved even for near-zero time steps
        with learnable time embeddings
        """
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Create a learnable time embedding with small initial value
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)

        # Generate mock latent and CDC components
        batch_size, latent_dim = 4, 64
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

        # Create mock eigenvectors and eigenvalues
        eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, 8, device=self.device)

        # Ensure eigenvectors and eigenvalues are meaningful
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

        # Use the mock dataset
        mock_dataset = MockGammaBDataset()

        # Compute noisy latent with gradient tracking
        noisy_latent = mock_dataset.compute_sigma_t_x(
            eigenvectors,
            eigenvalues,
            latent,
            t
        )

        # Compute a dummy loss to check gradient flow
        loss = noisy_latent.sum()

        # Compute gradients
        loss.backward()

        # Assertions to verify gradient flow
        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        # Check gradient magnitudes are non-zero
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()

        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"

        # Optional: Print gradient details for debugging
        print(f"Time embedding gradient magnitude: {t_grad_magnitude}")
        print(f"Latent gradient magnitude: {latent_grad_magnitude}")

    def test_gradient_flow_with_different_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        # Test time steps
        time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]

        for time_val in time_steps:
            # Create a learnable time embedding
            # (requires_grad=True keeps the t==0 shortcut in compute_sigma_t_x
            # from short-circuiting, so the full graph is always built)
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)

            # Generate mock latent and CDC components
            batch_size, latent_dim = 4, 64
            latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

            # Create mock eigenvectors and eigenvalues
            eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
            eigenvalues = torch.rand(batch_size, 8, device=self.device)

            # Ensure eigenvectors and eigenvalues are meaningful
            eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
            eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

            # Use the mock dataset
            mock_dataset = MockGammaBDataset()

            # Compute noisy latent with gradient tracking
            noisy_latent = mock_dataset.compute_sigma_t_x(
                eigenvectors,
                eigenvalues,
                latent,
                t
            )

            # Compute a dummy loss to check gradient flow
            loss = noisy_latent.sum()

            # Compute gradients
            loss.backward()

            # Assertions to verify gradient flow
            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()

            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"

            # Reset gradients for next iteration
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()
def pytest_configure(config):
    """
    Register the custom pytest markers used by the CDC-FM test suite.
    """
    marker_definition = "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
    config.addinivalue_line("markers", marker_definition)

View File

@@ -0,0 +1,285 @@
"""
Test CDC cache detection with multi-resolution filenames
This test verifies that _check_cdc_caches_exist() correctly detects CDC cache files
that include resolution information in their filenames (e.g., image_flux_cdc_104x80_hash.npz).
This was a bug where the check was looking for files without resolution
(image_flux_cdc_hash.npz) while the actual files had resolution in the name.
"""
import os
import tempfile
import shutil
from pathlib import Path
import numpy as np
import pytest
from library.train_util import DatasetGroup, ImageInfo
from library.cdc_fm import CDCPreprocessor
class MockDataset:
    """Minimal dataset stand-in exposing the attributes the cache check reads."""

    def __init__(self, image_data):
        # Fixed directory so the CDC config hash is deterministic in tests.
        self.image_dir = "/mock/dataset"
        self.image_data = image_data
        self.num_reg_images = 0
        self.num_train_images = len(image_data)

    def __len__(self):
        return len(self.image_data)
def test_cdc_cache_detection_with_resolution():
    """
    Test that CDC cache files with resolution in filename are properly detected.

    This reproduces the bug where:
    - CDC files are created with resolution: image_flux_cdc_104x80_hash.npz
    - But check looked for: image_flux_cdc_hash.npz
    - Result: Files not detected, unnecessary regeneration
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Setup: Create a mock latent cache file and corresponding CDC cache
        config_hash = "test1234"

        # Create latent cache file with multi-resolution format
        latent_path = Path(tmpdir) / "image_0832x0640_flux.npz"
        latent_shape = (16, 104, 80)  # C, H, W for resolution 832x640 (832/8=104, 640/8=80)

        # Save a mock latent file (key format must match the trainer's
        # multi-resolution cache convention "latents_{H}x{W}")
        np.savez(
            latent_path,
            **{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
        )

        # Create the CDC cache file with resolution in filename (as it's actually created)
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            str(latent_path),
            config_hash,
            latent_shape
        )

        # Verify the CDC path includes resolution
        assert "104x80" in cdc_path, f"CDC path should include resolution: {cdc_path}"

        # Create a mock CDC file
        np.savez(
            cdc_path,
            eigenvectors=np.random.randn(8, 16*104*80).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(latent_shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0
        )

        # Setup mock dataset
        image_info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(tmpdir) / "image.png")
        )
        image_info.latents_npz = str(latent_path)
        image_info.bucket_reso = (640, 832)  # W, H (note: reversed from latent shape H,W)
        image_info.latents = None  # Not in memory

        mock_dataset = MockDataset({"test_image": image_info})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if CDC cache is detected
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return True since the CDC file exists
        assert result is True, "CDC cache file should be detected when it exists with resolution in filename"
def test_cdc_cache_detection_missing_file():
    """A latent without its companion CDC cache file must be reported missing."""
    with tempfile.TemporaryDirectory() as workdir:
        hash_id = "test5678"

        # Write only the latent cache — deliberately no CDC file next to it.
        shape = (16, 96, 64)  # C, H, W
        npz_file = Path(workdir) / "image_0768x0512_flux.npz"
        payload = {f"latents_{shape[1]}x{shape[2]}": np.random.randn(*shape).astype(np.float32)}
        np.savez(npz_file, **payload)

        info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(workdir) / "image.png"),
        )
        info.latents = None
        info.latents_npz = str(npz_file)
        info.bucket_reso = (512, 768)  # W, H

        group = DatasetGroup([MockDataset({"test_image": info})])

        # The check must notice the absent CDC file and report failure.
        result = group._check_cdc_caches_exist(hash_id)
        assert result is False, "Should detect that CDC cache file is missing"
def test_cdc_cache_detection_with_in_memory_latent():
    """
    Test CDC cache detection when latent is already in memory (faster path).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        config_hash = "test_mem1"

        # Create latent cache file path (file may or may not exist)
        # — the in-memory path should not need to open it.
        latent_path = Path(tmpdir) / "image_1024x1024_flux.npz"
        latent_shape = (16, 128, 128)  # C, H, W

        # Create the CDC cache file
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            str(latent_path),
            config_hash,
            latent_shape
        )
        np.savez(
            cdc_path,
            eigenvectors=np.random.randn(8, 16*128*128).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(latent_shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0
        )

        # Setup mock dataset with latent in memory
        import torch
        image_info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(tmpdir) / "image.png")
        )
        image_info.latents_npz = str(latent_path)
        image_info.bucket_reso = (1024, 1024)  # W, H
        image_info.latents = torch.randn(latent_shape)  # In memory!

        mock_dataset = MockDataset({"test_image": image_info})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if CDC cache is detected (should use faster in-memory path)
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return True
        assert result is True, "CDC cache should be detected using in-memory latent shape"
def test_cdc_cache_detection_partial_cache():
    """
    Test that partial cache (some files exist, some don't) is correctly identified.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        config_hash = "testpart"

        # Create two latent files
        latent_path1 = Path(tmpdir) / "image1_0640x0512_flux.npz"
        latent_path2 = Path(tmpdir) / "image2_0640x0512_flux.npz"
        latent_shape = (16, 80, 64)

        for latent_path in [latent_path1, latent_path2]:
            np.savez(
                latent_path,
                **{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
            )

        # Create CDC cache for ONLY the first image
        cdc_path1 = CDCPreprocessor.get_cdc_npz_path(str(latent_path1), config_hash, latent_shape)
        np.savez(
            cdc_path1,
            eigenvectors=np.random.randn(8, 16*80*64).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(latent_shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0
        )
        # CDC cache for second image does NOT exist

        # Setup mock dataset with both images
        # (ImageInfo positional args: image_key, num_repeats, caption, is_reg, absolute_path)
        info1 = ImageInfo("img1", 1, "test", False, str(Path(tmpdir) / "img1.png"))
        info1.latents_npz = str(latent_path1)
        info1.bucket_reso = (512, 640)
        info1.latents = None

        info2 = ImageInfo("img2", 1, "test", False, str(Path(tmpdir) / "img2.png"))
        info2.latents_npz = str(latent_path2)
        info2.bucket_reso = (512, 640)
        info2.latents = None

        mock_dataset = MockDataset({"img1": info1, "img2": info2})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if all CDC caches exist
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return False since not all files exist
        assert result is False, "Should detect that some CDC cache files are missing"
def test_cdc_requires_latent_caching():
    """CDC preprocessing must fail with a clear error when no latents are cached."""
    with tempfile.TemporaryDirectory() as workdir:
        info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(workdir) / "image.png"),
        )
        # Neither a disk cache nor an in-memory latent is available.
        info.latents_npz = None
        info.latents = None
        info.bucket_reso = (512, 512)

        group = DatasetGroup([MockDataset({"test_image": info})])

        with pytest.raises(ValueError) as exc_info:
            group.cache_cdc_gamma_b(
                k_neighbors=256,
                k_bandwidth=8,
                d_cdc=8,
                gamma=1.0
            )

        # The error text must tell the user exactly which options to enable.
        message = str(exc_info.value)
        assert "CDC-FM requires latent caching" in message
        assert "cache_latents" in message
        assert "cache_latents_to_disk" in message
# Allow running this module directly without invoking the pytest CLI.
if __name__ == "__main__":
    # Run tests with verbose output
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,157 @@
"""
Test CDC config hash generation and cache invalidation
"""
import pytest
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor
class TestCDCConfigHash:
    """
    Verify that the CDC config hash changes exactly when the dataset layout or
    any CDC hyper-parameter changes, so stale caches are properly invalidated.
    """

    @staticmethod
    def _build(tmp_path, *, dirs=("dataset1",), **overrides):
        """Construct a CDCPreprocessor with the suite's baseline parameters."""
        params = {
            "k_neighbors": 5,
            "k_bandwidth": 3,
            "d_cdc": 4,
            "gamma": 1.0,
            "device": "cpu",
            "dataset_dirs": [str(tmp_path / d) for d in dirs],
        }
        params.update(overrides)
        return CDCPreprocessor(**params)

    def test_same_config_produces_same_hash(self, tmp_path):
        """Identical configurations must hash identically."""
        first = self._build(tmp_path)
        second = self._build(tmp_path)
        assert first.config_hash == second.config_hash

    def test_different_dataset_dirs_produce_different_hash(self, tmp_path):
        """Changing the dataset directory must change the hash."""
        base = self._build(tmp_path)
        other = self._build(tmp_path, dirs=("dataset2",))
        assert base.config_hash != other.config_hash

    def test_different_k_neighbors_produces_different_hash(self, tmp_path):
        """Changing k_neighbors must change the hash."""
        base = self._build(tmp_path)
        other = self._build(tmp_path, k_neighbors=10)
        assert base.config_hash != other.config_hash

    def test_different_d_cdc_produces_different_hash(self, tmp_path):
        """Changing d_cdc must change the hash."""
        base = self._build(tmp_path)
        other = self._build(tmp_path, d_cdc=8)
        assert base.config_hash != other.config_hash

    def test_different_gamma_produces_different_hash(self, tmp_path):
        """Changing gamma must change the hash."""
        base = self._build(tmp_path)
        other = self._build(tmp_path, gamma=2.0)
        assert base.config_hash != other.config_hash

    def test_multiple_dataset_dirs_order_independent(self, tmp_path):
        """Directory order must not matter (directories are sorted internally)."""
        forward = self._build(tmp_path, dirs=("dataset1", "dataset2"))
        backward = self._build(tmp_path, dirs=("dataset2", "dataset1"))
        assert forward.config_hash == backward.config_hash

    def test_hash_length_is_8_chars(self, tmp_path):
        """The hash is exactly 8 hexadecimal characters."""
        digest = self._build(tmp_path).config_hash
        assert len(digest) == 8
        int(digest, 16)  # Should not raise — must parse as hexadecimal

    def test_filename_includes_hash(self, tmp_path):
        """CDC cache filenames embed the config hash."""
        digest = self._build(tmp_path).config_hash
        latents_path = str(tmp_path / "image_0512x0768_flux.npz")
        cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, digest)
        # Expected layout: image_0512x0768_flux_cdc_<hash>.npz
        expected = str(tmp_path / f"image_0512x0768_flux_cdc_{digest}.npz")
        assert cdc_path == expected

    def test_backward_compatibility_no_hash(self, tmp_path):
        """Without a hash the legacy filename (no suffix) is produced."""
        latents_path = str(tmp_path / "image_0512x0768_flux.npz")
        cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None)
        expected = str(tmp_path / "image_0512x0768_flux_cdc.npz")
        assert cdc_path == expected
# Allow running this module directly without invoking the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,234 @@
"""
Test CDC-FM multi-resolution support
This test verifies that CDC files are correctly created and loaded for different
resolutions, preventing dimension mismatch errors in multi-resolution training.
"""
import torch
import numpy as np
from pathlib import Path
import pytest
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCMultiResolution:
    """Test CDC multi-resolution caching and loading.

    Verifies that per-file CDC caches are resolution-specific: the same source
    image trained at two latent resolutions must produce two distinct cache
    files, and GammaBDataset must load the one matching the requested shape.
    """

    def test_different_resolutions_create_separate_cdc_files(self, tmp_path):
        """
        Test that the same image with different latent resolutions creates
        separate CDC cache files.
        """
        # Create preprocessor
        preprocessor = CDCPreprocessor(
            k_neighbors=5,
            k_bandwidth=3,
            d_cdc=4,
            gamma=1.0,
            device="cpu",
            dataset_dirs=[str(tmp_path)]
        )
        # Same image, two different resolutions
        image_base_path = str(tmp_path / "test_image_1200x1500_flux.npz")
        # Resolution 1: 64x48 (simulating resolution=512 training)
        latent_64x48 = torch.randn(16, 64, 48, dtype=torch.float32)
        for i in range(10):  # Need multiple samples for CDC
            preprocessor.add_latent(
                latent=latent_64x48,
                global_idx=i,
                latents_npz_path=image_base_path,
                shape=latent_64x48.shape,
                metadata={'image_key': f'test_image_{i}'}
            )
        # Compute and save
        files_saved = preprocessor.compute_all()
        assert files_saved == 10
        # Verify CDC file for 64x48 exists with shape in filename
        cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path(
            image_base_path,
            preprocessor.config_hash,
            latent_shape=(16, 64, 48)
        )
        assert Path(cdc_path_64x48).exists()
        assert "64x48" in cdc_path_64x48
        # Create new preprocessor for resolution 2
        preprocessor2 = CDCPreprocessor(
            k_neighbors=5,
            k_bandwidth=3,
            d_cdc=4,
            gamma=1.0,
            device="cpu",
            dataset_dirs=[str(tmp_path)]
        )
        # Resolution 2: 104x80 (simulating resolution=768 training)
        latent_104x80 = torch.randn(16, 104, 80, dtype=torch.float32)
        for i in range(10):
            preprocessor2.add_latent(
                latent=latent_104x80,
                global_idx=i,
                latents_npz_path=image_base_path,
                shape=latent_104x80.shape,
                metadata={'image_key': f'test_image_{i}'}
            )
        files_saved2 = preprocessor2.compute_all()
        assert files_saved2 == 10
        # Verify CDC file for 104x80 exists with different shape in filename
        cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path(
            image_base_path,
            preprocessor2.config_hash,
            latent_shape=(16, 104, 80)
        )
        assert Path(cdc_path_104x80).exists()
        assert "104x80" in cdc_path_104x80
        # Verify both files exist and are different
        assert cdc_path_64x48 != cdc_path_104x80
        assert Path(cdc_path_64x48).exists()
        assert Path(cdc_path_104x80).exists()
        # Verify the CDC files have different dimensions.
        # np.load on an .npz returns a lazy NpzFile that keeps the file open;
        # use context managers so the handles are released.
        with np.load(cdc_path_64x48) as data_64x48, np.load(cdc_path_104x80) as data_104x80:
            # 64x48 -> flattened dim = 16 * 64 * 48 = 49152
            # 104x80 -> flattened dim = 16 * 104 * 80 = 133120
            assert data_64x48['eigenvectors'].shape[1] == 16 * 64 * 48
            assert data_104x80['eigenvectors'].shape[1] == 16 * 104 * 80

    def test_loading_correct_cdc_for_resolution(self, tmp_path):
        """
        Test that GammaBDataset loads the correct CDC file based on latent_shape
        """
        # Create and save CDC files for two resolutions
        config_hash = "testHash"
        image_path = str(tmp_path / "test_image_flux.npz")
        # Create CDC file for 64x48
        cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path(
            image_path,
            config_hash,
            latent_shape=(16, 64, 48)
        )
        eigvecs_64x48 = np.random.randn(4, 16 * 64 * 48).astype(np.float16)
        eigvals_64x48 = np.random.randn(4).astype(np.float16)
        np.savez(
            cdc_path_64x48,
            eigenvectors=eigvecs_64x48,
            eigenvalues=eigvals_64x48,
            shape=np.array([16, 64, 48])
        )
        # Create CDC file for 104x80
        cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path(
            image_path,
            config_hash,
            latent_shape=(16, 104, 80)
        )
        eigvecs_104x80 = np.random.randn(4, 16 * 104 * 80).astype(np.float16)
        eigvals_104x80 = np.random.randn(4).astype(np.float16)
        np.savez(
            cdc_path_104x80,
            eigenvectors=eigvecs_104x80,
            eigenvalues=eigvals_104x80,
            shape=np.array([16, 104, 80])
        )
        # Create GammaBDataset
        gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
        # Load with 64x48 shape
        eigvecs_loaded, eigvals_loaded = gamma_b_dataset.get_gamma_b_sqrt(
            [image_path],
            device="cpu",
            latent_shape=(16, 64, 48)
        )
        assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48)
        # Load with 104x80 shape
        eigvecs_loaded2, eigvals_loaded2 = gamma_b_dataset.get_gamma_b_sqrt(
            [image_path],
            device="cpu",
            latent_shape=(16, 104, 80)
        )
        assert eigvecs_loaded2.shape == (1, 4, 16 * 104 * 80)
        # Verify different dimensions were loaded
        assert eigvecs_loaded.shape[2] != eigvecs_loaded2.shape[2]

    def test_error_when_latent_shape_not_provided_for_multireso(self, tmp_path):
        """
        Test that loading without latent_shape still works for backward compatibility
        but will use old filename format without resolution
        """
        config_hash = "testHash"
        image_path = str(tmp_path / "test_image_flux.npz")
        # Create CDC file with old naming (no latent shape)
        cdc_path_old = CDCPreprocessor.get_cdc_npz_path(
            image_path,
            config_hash,
            latent_shape=None  # Old format
        )
        eigvecs = np.random.randn(4, 16 * 64 * 48).astype(np.float16)
        eigvals = np.random.randn(4).astype(np.float16)
        np.savez(
            cdc_path_old,
            eigenvectors=eigvecs,
            eigenvalues=eigvals,
            shape=np.array([16, 64, 48])
        )
        # Load without latent_shape (backward compatibility)
        gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
        eigvecs_loaded, eigvals_loaded = gamma_b_dataset.get_gamma_b_sqrt(
            [image_path],
            device="cpu",
            latent_shape=None
        )
        assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48)

    def test_filename_format_with_latent_shape(self):
        """Test that CDC filenames include latent dimensions correctly"""
        base_path = "/path/to/image_1200x1500_flux.npz"
        config_hash = "abc123de"
        # With latent shape
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            base_path,
            config_hash,
            latent_shape=(16, 104, 80)
        )
        # Should include latent H×W in filename
        assert "104x80" in cdc_path
        assert config_hash in cdc_path
        assert cdc_path.endswith("_flux_cdc_104x80_abc123de.npz")

    def test_filename_format_without_latent_shape(self):
        """Test backward compatible filename without latent shape"""
        base_path = "/path/to/image_1200x1500_flux.npz"
        config_hash = "abc123de"
        # Without latent shape (old format)
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            base_path,
            config_hash,
            latent_shape=None
        )
        # Should NOT include latent dimensions
        assert "104x80" not in cdc_path
        assert "64x48" not in cdc_path
        assert config_hash in cdc_path
        assert cdc_path.endswith("_flux_cdc_abc123de.npz")

View File

@@ -0,0 +1,322 @@
"""
CDC Preprocessor and Device Consistency Tests
This module provides testing of:
1. CDC Preprocessor functionality
2. Device consistency handling
3. GammaBDataset loading and usage
4. End-to-end CDC workflow verification
"""
import pytest
import logging
import torch
from pathlib import Path
from safetensors.torch import save_file
from safetensors import safe_open
from library.cdc_fm import CDCPreprocessor, GammaBDataset
from library.flux_train_utils import apply_cdc_noise_transformation
class TestCDCPreprocessorIntegration:
    """
    Comprehensive testing of CDC preprocessing and device handling
    """

    def test_basic_preprocessor_workflow(self, tmp_path):
        """
        Test basic CDC preprocessing with small dataset
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]  # Add dataset_dirs for hash
        )
        # Add 10 small latents
        for i in range(10):
            latent = torch.randn(16, 4, 4, dtype=torch.float32)  # C, H, W
            latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata=metadata
            )
        # Compute and save
        files_saved = preprocessor.compute_all()
        # Verify files were created
        assert files_saved == 10
        # Verify first CDC file structure (with config hash and latent shape)
        latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
        latent_shape = (16, 4, 4)
        cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
        assert cdc_path.exists()
        import numpy as np
        # np.load on an .npz returns a lazy NpzFile that keeps the file open;
        # close it via a context manager once the arrays have been checked.
        with np.load(cdc_path) as data:
            assert data['k_neighbors'] == 5
            assert data['d_cdc'] == 4
            # Check eigenvectors and eigenvalues
            eigvecs = data['eigenvectors']
            eigvals = data['eigenvalues']
            assert eigvecs.shape[0] == 4  # d_cdc
            assert eigvals.shape[0] == 4  # d_cdc

    def test_preprocessor_with_different_shapes(self, tmp_path):
        """
        Test CDC preprocessing with variable-size latents (bucketing)
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]  # Add dataset_dirs for hash
        )
        # Add 5 latents of shape (16, 4, 4)
        for i in range(5):
            latent = torch.randn(16, 4, 4, dtype=torch.float32)
            latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata=metadata
            )
        # Add 5 latents of different shape (16, 8, 8)
        for i in range(5, 10):
            latent = torch.randn(16, 8, 8, dtype=torch.float32)
            latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata=metadata
            )
        # Compute and save
        files_saved = preprocessor.compute_all()
        # Verify both shape groups were processed
        assert files_saved == 10
        import numpy as np
        # Check shapes are stored in individual files (with config hash and latent shape)
        cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
        )
        cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
        )
        # Close both NpzFile handles deterministically (see note above about np.load)
        with np.load(cdc_path_0) as data_0, np.load(cdc_path_5) as data_5:
            assert tuple(data_0['shape']) == (16, 4, 4)
            assert tuple(data_5['shape']) == (16, 8, 8)
class TestDeviceConsistency:
    """
    Test device handling and consistency for CDC transformations
    """

    @staticmethod
    def _build_cpu_cache(tmp_path):
        """Preprocess 10 random (16, 32, 32) latents on CPU; return (dataset, paths, shape)."""
        shape = (16, 32, 32)
        preprocessor = CDCPreprocessor(
            k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]  # dataset_dirs feeds the config hash
        )
        paths = []
        for idx in range(10):
            npz_path = str(tmp_path / f"test_image_{idx}_0032x0032_flux.npz")
            paths.append(npz_path)
            preprocessor.add_latent(
                latent=torch.randn(*shape, dtype=torch.float32),
                global_idx=idx,
                latents_npz_path=npz_path,
                shape=shape,
                metadata={'image_key': f'test_image_{idx}'}
            )
        preprocessor.compute_all()
        dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash)
        return dataset, paths, shape

    def test_matching_devices_no_warning(self, tmp_path, caplog):
        """No device-mismatch warnings should be logged when everything lives on CPU."""
        dataset, paths, shape = self._build_cpu_cache(tmp_path)
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu")
        timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
        with caplog.at_level(logging.WARNING):
            caplog.clear()
            _ = apply_cdc_noise_transformation(
                noise=noise,
                timesteps=timesteps,
                num_timesteps=1000,
                gamma_b_dataset=dataset,
                latents_npz_paths=paths[:2],
                device="cpu"
            )
        # Scan captured records for any device-mismatch complaint
        mismatch_records = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()]
        assert len(mismatch_records) == 0, "Should not warn when devices match"

    def test_device_mismatch_handling(self, tmp_path):
        """The CDC transform must stay finite and differentiable on CPU inputs."""
        dataset, paths, shape = self._build_cpu_cache(tmp_path)
        noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True)
        timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu")
        result = apply_cdc_noise_transformation(
            noise=noise,
            timesteps=timesteps,
            num_timesteps=1000,
            gamma_b_dataset=dataset,
            latents_npz_paths=paths[:2],
            device="cpu"
        )
        # Output mirrors the input tensor and keeps autograd tracking alive
        assert result.shape == noise.shape
        assert result.device == noise.device
        assert result.requires_grad  # Gradients should still work
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()
        # Gradients must flow back to the input noise
        result.sum().backward()
        assert noise.grad is not None
class TestCDCEndToEnd:
    """
    End-to-end CDC workflow tests
    """

    def test_full_preprocessing_usage_workflow(self, tmp_path):
        """Exercise the complete pipeline: preprocess -> save -> load -> consume."""
        # Step 1: preprocess a small set of random latents
        pp = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]  # dataset_dirs feeds the config hash
        )
        num_samples = 10
        npz_paths = []
        for idx in range(num_samples):
            latent = torch.randn(16, 4, 4, dtype=torch.float32)
            path = str(tmp_path / f"test_image_{idx}_0004x0004_flux.npz")
            npz_paths.append(path)
            pp.add_latent(
                latent=latent,
                global_idx=idx,
                latents_npz_path=path,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{idx}'}
            )
        assert pp.compute_all() == num_samples
        # Step 2: load the per-file caches via the config hash
        gamma_b = GammaBDataset(device="cpu", config_hash=pp.config_hash)
        # Step 3: simulate a training batch (16*4*4 = 256 when flattened)
        batch_size = 3
        flat_latents = torch.randn(batch_size, 256)
        t_batch = torch.rand(batch_size)
        batch_paths = [npz_paths[0], npz_paths[5], npz_paths[9]]
        # Fetch Γ_b components (latent_shape enables multi-resolution lookup)
        eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(batch_paths, device="cpu", latent_shape=(16, 4, 4))
        sigma_t_x = gamma_b.compute_sigma_t_x(eigvecs, eigvals, flat_latents, t_batch)
        # Geometry-aware noise must be well-formed
        assert sigma_t_x.shape == flat_latents.shape
        assert not torch.isnan(sigma_t_x).any()
        assert torch.isfinite(sigma_t_x).all()
        # t=0 reproduces x exactly; t=1 must diverge from it
        at_t0 = gamma_b.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.zeros(batch_size))
        at_t1 = gamma_b.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.ones(batch_size))
        assert torch.allclose(at_t0, flat_latents, atol=1e-6)
        assert not torch.allclose(at_t1, flat_latents, atol=0.1)
def pytest_configure(config):
    """Register the custom markers used by the CDC test-suite."""
    marker_lines = (
        "device_consistency: mark test to verify device handling in CDC transformations",
        "preprocessor: mark test to verify CDC preprocessing workflow",
        "end_to_end: mark test to verify full CDC workflow",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
# Allow running this module directly with verbose (-v) and unbuffered (-s) output.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,299 @@
"""
Standalone tests for CDC-FM per-file caching.
These tests focus on the current CDC-FM per-file caching implementation
with hash-based cache validation.
"""
from pathlib import Path
import pytest
import torch
import numpy as np
from library.cdc_fm import CDCPreprocessor, GammaBDataset
class TestCDCPreprocessor:
    """Test CDC preprocessing functionality with per-file caching"""

    def test_cdc_preprocessor_basic_workflow(self, tmp_path):
        """Test basic CDC preprocessing with small dataset"""
        preprocessor = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]
        )
        # Add 10 small latents
        for i in range(10):
            latent = torch.randn(16, 4, 4, dtype=torch.float32)  # C, H, W
            latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata=metadata
            )
        # Compute and save (creates per-file CDC caches)
        files_saved = preprocessor.compute_all()
        # Verify files were created
        assert files_saved == 10
        # Verify first CDC file structure
        latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
        latent_shape = (16, 4, 4)
        cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
        assert cdc_path.exists()
        # np.load on an .npz returns a lazy NpzFile that keeps the file open;
        # close it via a context manager once the arrays have been checked.
        with np.load(cdc_path) as data:
            assert data['k_neighbors'] == 5
            assert data['d_cdc'] == 4
            # Check eigenvectors and eigenvalues
            eigvecs = data['eigenvectors']
            eigvals = data['eigenvalues']
            assert eigvecs.shape[0] == 4  # d_cdc
            assert eigvals.shape[0] == 4  # d_cdc

    def test_cdc_preprocessor_different_shapes(self, tmp_path):
        """Test CDC preprocessing with variable-size latents (bucketing)"""
        preprocessor = CDCPreprocessor(
            k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]
        )
        # Add 5 latents of shape (16, 4, 4)
        for i in range(5):
            latent = torch.randn(16, 4, 4, dtype=torch.float32)
            latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz")
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata=metadata
            )
        # Add 5 latents of different shape (16, 8, 8)
        for i in range(5, 10):
            latent = torch.randn(16, 8, 8, dtype=torch.float32)
            latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz")
            metadata = {'image_key': f'test_image_{i}'}
            preprocessor.add_latent(
                latent=latent,
                global_idx=i,
                latents_npz_path=latents_npz_path,
                shape=latent.shape,
                metadata=metadata
            )
        # Compute and save
        files_saved = preprocessor.compute_all()
        # Verify both shape groups were processed
        assert files_saved == 10
        # Check shapes are stored in individual files
        cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
        )
        cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
            str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
        )
        # Close both NpzFile handles deterministically (see note above about np.load)
        with np.load(cdc_path_0) as data_0, np.load(cdc_path_5) as data_5:
            assert tuple(data_0['shape']) == (16, 4, 4)
            assert tuple(data_5['shape']) == (16, 8, 8)
class TestGammaBDataset:
    """Test GammaBDataset loading and retrieval with per-file caching"""

    @pytest.fixture
    def sample_cdc_cache(self, tmp_path):
        """Build 20 per-file CDC caches; return (tmp_path, npz_paths, config_hash).

        20 samples with k=5 is enough for the k-NN step in tests, even though
        production guidance recommends a minimum of 256 neighbors.
        """
        pp = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)],
            adaptive_k=True,  # shrink k automatically for this tiny dataset
            min_bucket_size=5
        )
        npz_paths = []
        for idx in range(20):
            sample = torch.randn(16, 8, 8, dtype=torch.float32)  # flattens to d=1024
            path = str(tmp_path / f"test_{idx}_0008x0008_flux.npz")
            npz_paths.append(path)
            pp.add_latent(
                latent=sample,
                global_idx=idx,
                latents_npz_path=path,
                shape=sample.shape,
                metadata={'image_key': f'test_{idx}'}
            )
        pp.compute_all()
        return tmp_path, npz_paths, pp.config_hash

    def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
        """GammaBDataset must read a single per-file cache back correctly."""
        _, npz_paths, config_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=config_hash)
        # latent_shape enables multi-resolution lookup
        eigvecs, eigvals = ds.get_gamma_b_sqrt([npz_paths[0]], device="cpu", latent_shape=(16, 8, 8))
        assert eigvecs.shape[0] == 1  # batch size
        assert eigvecs.shape[1] == 4  # d_cdc
        assert eigvals.shape == (1, 4)  # batch, d_cdc

    def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
        """Batched Γ_b^(1/2) retrieval yields positive eigenvalues of the right shape."""
        _, npz_paths, config_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=config_hash)
        batch = [npz_paths[0], npz_paths[2], npz_paths[4]]
        eigenvectors, eigenvalues = ds.get_gamma_b_sqrt(batch, device="cpu", latent_shape=(16, 8, 8))
        assert eigenvectors.shape[0] == 3  # batch
        assert eigenvectors.shape[1] == 4  # d_cdc
        assert eigenvalues.shape == (3, 4)  # (batch, d_cdc)
        assert torch.all(eigenvalues > 0)

    def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
        """At t=0 the geometry-aware interpolation must return x unchanged."""
        _, npz_paths, config_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=config_hash)
        x = torch.randn(3, 1024)  # batch of flattened (16, 8, 8) latents
        batch = [npz_paths[0], npz_paths[1], npz_paths[2]]
        eigvecs, eigvals = ds.get_gamma_b_sqrt(batch, device="cpu", latent_shape=(16, 8, 8))
        out = ds.compute_sigma_t_x(eigvecs, eigvals, x, torch.zeros(3))
        assert torch.allclose(out, x, atol=1e-6)

    def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
        """compute_sigma_t_x preserves the input's shape."""
        _, npz_paths, config_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=config_hash)
        x = torch.randn(2, 1024)  # B, d (flattened)
        batch = [npz_paths[1], npz_paths[3]]
        eigvecs, eigvals = ds.get_gamma_b_sqrt(batch, device="cpu", latent_shape=(16, 8, 8))
        out = ds.compute_sigma_t_x(eigvecs, eigvals, x, torch.tensor([0.3, 0.7]))
        assert out.shape == x.shape

    def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
        """compute_sigma_t_x must produce only finite values at random timesteps."""
        _, npz_paths, config_hash = sample_cdc_cache
        ds = GammaBDataset(device="cpu", config_hash=config_hash)
        x = torch.randn(3, 1024)  # B, d (flattened)
        batch = [npz_paths[0], npz_paths[2], npz_paths[4]]
        eigvecs, eigvals = ds.get_gamma_b_sqrt(batch, device="cpu", latent_shape=(16, 8, 8))
        out = ds.compute_sigma_t_x(eigvecs, eigvals, x, torch.rand(3))
        assert not torch.isnan(out).any()
        assert torch.isfinite(out).all()
class TestCDCEndToEnd:
    """End-to-end CDC workflow tests"""

    def test_full_preprocessing_and_usage_workflow(self, tmp_path):
        """Test complete workflow: preprocess -> save -> load -> use"""
        # Step 1: build per-file CDC caches from random latents
        pp = CDCPreprocessor(
            k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu",
            dataset_dirs=[str(tmp_path)]
        )
        num_samples = 10
        npz_paths = []
        for idx in range(num_samples):
            latent = torch.randn(16, 4, 4, dtype=torch.float32)
            path = str(tmp_path / f"test_image_{idx}_0004x0004_flux.npz")
            npz_paths.append(path)
            pp.add_latent(
                latent=latent,
                global_idx=idx,
                latents_npz_path=path,
                shape=latent.shape,
                metadata={'image_key': f'test_image_{idx}'}
            )
        assert pp.compute_all() == num_samples
        # Step 2: reload the caches through GammaBDataset
        gamma_b = GammaBDataset(device="cpu", config_hash=pp.config_hash)
        # Step 3: simulate a training batch (16*4*4 = 256 when flattened)
        batch_size = 3
        flat_latents = torch.randn(batch_size, 256)
        t_batch = torch.rand(batch_size)
        batch_paths = [npz_paths[0], npz_paths[5], npz_paths[9]]
        # Fetch Γ_b components
        eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(batch_paths, device="cpu", latent_shape=(16, 4, 4))
        # Compute geometry-aware noise
        sigma_t_x = gamma_b.compute_sigma_t_x(eigvecs, eigvals, flat_latents, t_batch)
        # Output must be well-formed
        assert sigma_t_x.shape == flat_latents.shape
        assert not torch.isnan(sigma_t_x).any()
        assert torch.isfinite(sigma_t_x).all()
        # t=0 reproduces x exactly; t=1 must diverge from it
        at_t0 = gamma_b.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.zeros(batch_size))
        at_t1 = gamma_b.compute_sigma_t_x(eigvecs, eigvals, flat_latents, torch.ones(batch_size))
        assert torch.allclose(at_t0, flat_latents, atol=1e-6)
        assert not torch.allclose(at_t1, flat_latents, atol=0.1)
# Allow running this module directly (outside a pytest invocation) with verbose output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -2,7 +2,7 @@ import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
get_noisy_model_input_and_timestep,
)
# Mock classes and functions
@@ -66,7 +66,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -80,7 +80,7 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -93,7 +93,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -105,7 +105,7 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -126,7 +126,7 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, device, dtype
)
@@ -141,7 +141,7 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -153,7 +153,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -164,7 +164,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
dtype = torch.float16
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
@@ -176,7 +176,7 @@ def test_different_batch_size(args, noise_scheduler, device):
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
@@ -189,7 +189,7 @@ def test_different_image_size(args, noise_scheduler, device):
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
@@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
@@ -212,7 +212,7 @@ def test_different_timestep_count(args, device):
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)

View File

@@ -622,6 +622,27 @@ class NetworkTrainer:
accelerator.wait_for_everyone()
# CDC-FM preprocessing
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
self.cdc_config_hash = train_dataset_group.cache_cdc_gamma_b(
k_neighbors=args.cdc_k_neighbors,
k_bandwidth=args.cdc_k_bandwidth,
d_cdc=args.cdc_d_cdc,
gamma=args.cdc_gamma,
force_recache=args.force_recache_cdc,
accelerator=accelerator,
debug=getattr(args, 'cdc_debug', False),
adaptive_k=getattr(args, 'cdc_adaptive_k', False),
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
)
if self.cdc_config_hash is None:
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
else:
self.cdc_config_hash = None
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
text_encoding_strategy = self.get_text_encoding_strategy(args)
@@ -660,6 +681,19 @@ class NetworkTrainer:
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# Load CDC-FM Γ_b dataset if enabled
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_config_hash is not None:
from library.cdc_fm import GammaBDataset
logger.info(f"CDC Γ_b dataset ready (hash: {self.cdc_config_hash})")
self.gamma_b_dataset = GammaBDataset(
device="cuda" if torch.cuda.is_available() else "cpu",
config_hash=self.cdc_config_hash
)
else:
self.gamma_b_dataset = None
# prepare network
net_kwargs = {}
if args.network_args is not None: