mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Add CDC-FM (Carré du Champ Flow Matching) support
Implements geometry-aware noise generation for FLUX training based on arXiv:2510.05930v1.
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -36,6 +34,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
self.model_type: Optional[str] = None
|
||||
self.gamma_b_dataset = None # CDC-FM Γ_b dataset
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
@@ -327,9 +326,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
# Get CDC parameters if enabled
|
||||
gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "indices" in batch) else None
|
||||
batch_indices = batch.get("indices") if gamma_b_dataset is not None else None
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
# If CDC is enabled, this will transform the noise with geometry-aware covariance
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
|
||||
gamma_b_dataset=gamma_b_dataset, batch_indices=batch_indices
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
@@ -494,7 +499,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
module.forward = forward_hook(module)
|
||||
|
||||
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
||||
logger.info(f"T5XXL already prepared for fp8")
|
||||
logger.info("T5XXL already prepared for fp8")
|
||||
else:
|
||||
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
@@ -533,6 +538,49 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
|
||||
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
|
||||
)
|
||||
|
||||
# CDC-FM arguments
|
||||
parser.add_argument(
|
||||
"--use_cdc_fm",
|
||||
action="store_true",
|
||||
help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
|
||||
" / CDC-FM(Carré du Champ Flow Matching)を有効にして幾何学的ノイズを使用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_k_neighbors",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
|
||||
" / CDC-FMのk-NNグラフの近傍数(デフォルト: 256)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_k_bandwidth",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
|
||||
" / CDC-FMの帯域幅推定の近傍数(デフォルト: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_d_cdc",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Dimension of CDC subspace (default: 8)"
|
||||
" / CDCサブ空間の次元(デフォルト: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_gamma",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="CDC strength parameter (default: 1.0)"
|
||||
" / CDC強度パラメータ(デフォルト: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_recache_cdc",
|
||||
action="store_true",
|
||||
help="Force recompute CDC cache even if valid cache exists"
|
||||
" / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
712
library/cdc_fm.py
Normal file
712
library/cdc_fm.py
Normal file
@@ -0,0 +1,712 @@
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import faiss # type: ignore
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from safetensors.torch import save_file
|
||||
from typing import List, Dict, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LatentSample:
|
||||
"""
|
||||
Container for a single latent with metadata
|
||||
"""
|
||||
latent: np.ndarray # (d,) flattened latent vector
|
||||
global_idx: int # Global index in dataset
|
||||
shape: Tuple[int, ...] # Original shape before flattening (C, H, W)
|
||||
metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.)
|
||||
|
||||
|
||||
class CarreDuChampComputer:
|
||||
"""
|
||||
Core CDC-FM computation - agnostic to data source
|
||||
Just handles the math for a batch of same-size latents
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_neighbors: int = 256,
|
||||
k_bandwidth: int = 8,
|
||||
d_cdc: int = 8,
|
||||
gamma: float = 1.0,
|
||||
device: str = 'cuda'
|
||||
):
|
||||
self.k = k_neighbors
|
||||
self.k_bw = k_bandwidth
|
||||
self.d_cdc = d_cdc
|
||||
self.gamma = gamma
|
||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Build k-NN graph using FAISS
|
||||
|
||||
Args:
|
||||
latents_np: (N, d) numpy array of same-dimensional latents
|
||||
|
||||
Returns:
|
||||
distances: (N, k_actual+1) distances (k_actual may be less than k if N is small)
|
||||
indices: (N, k_actual+1) neighbor indices
|
||||
"""
|
||||
N, d = latents_np.shape
|
||||
|
||||
# Clamp k to available neighbors (can't have more neighbors than samples)
|
||||
k_actual = min(self.k, N - 1)
|
||||
|
||||
# Ensure float32
|
||||
if latents_np.dtype != np.float32:
|
||||
latents_np = latents_np.astype(np.float32)
|
||||
|
||||
# Build FAISS index
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
res = faiss.StandardGpuResources()
|
||||
index = faiss.index_cpu_to_gpu(res, 0, index)
|
||||
|
||||
index.add(latents_np) # type: ignore
|
||||
distances, indices = index.search(latents_np, k_actual + 1) # type: ignore
|
||||
|
||||
return distances, indices
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_gamma_b_single(
|
||||
self,
|
||||
point_idx: int,
|
||||
latents_np: np.ndarray,
|
||||
distances: np.ndarray,
|
||||
indices: np.ndarray,
|
||||
epsilon: np.ndarray
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute Γ_b for a single point
|
||||
|
||||
Args:
|
||||
point_idx: Index of point to process
|
||||
latents_np: (N, d) all latents in this batch
|
||||
distances: (N, k+1) precomputed distances
|
||||
indices: (N, k+1) precomputed neighbor indices
|
||||
epsilon: (N,) bandwidth per point
|
||||
|
||||
Returns:
|
||||
eigenvectors: (d_cdc, d) as half precision tensor
|
||||
eigenvalues: (d_cdc,) as half precision tensor
|
||||
"""
|
||||
d = latents_np.shape[1]
|
||||
|
||||
# Get neighbors (exclude self)
|
||||
neighbor_idx = indices[point_idx, 1:] # (k,)
|
||||
neighbor_points = latents_np[neighbor_idx] # (k, d)
|
||||
|
||||
# Clamp distances to prevent overflow (max realistic L2 distance)
|
||||
MAX_DISTANCE = 1e10
|
||||
neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
|
||||
neighbor_dists_sq = neighbor_dists ** 2 # (k,)
|
||||
|
||||
# Compute Gaussian kernel weights with numerical guards
|
||||
eps_i = max(epsilon[point_idx], 1e-10) # Prevent division by zero
|
||||
eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)
|
||||
|
||||
# Compute denominator with guard against overflow
|
||||
denom = eps_i * eps_neighbors
|
||||
denom = np.maximum(denom, 1e-20) # Additional guard
|
||||
|
||||
# Compute weights with safe exponential
|
||||
exp_arg = -neighbor_dists_sq / denom
|
||||
exp_arg = np.clip(exp_arg, -50, 0) # Prevent exp overflow/underflow
|
||||
weights = np.exp(exp_arg)
|
||||
|
||||
# Normalize weights, handle edge case of all zeros
|
||||
weight_sum = weights.sum()
|
||||
if weight_sum < 1e-20 or not np.isfinite(weight_sum):
|
||||
# Fallback to uniform weights
|
||||
weights = np.ones_like(weights) / len(weights)
|
||||
else:
|
||||
weights = weights / weight_sum
|
||||
|
||||
# Compute local mean
|
||||
m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
|
||||
|
||||
# Center and weight for SVD
|
||||
centered = neighbor_points - m_star
|
||||
weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d)
|
||||
|
||||
# Validate input is finite before SVD
|
||||
if not np.all(np.isfinite(weighted_centered)):
|
||||
logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
|
||||
# Fallback: use uniform weights and simple centering
|
||||
weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
|
||||
m_star = np.mean(neighbor_points, axis=0)
|
||||
centered = neighbor_points - m_star
|
||||
weighted_centered = np.sqrt(weights_uniform)[:, None] * centered
|
||||
|
||||
# Move to GPU for SVD (100x speedup!)
|
||||
weighted_centered_torch = torch.from_numpy(weighted_centered).to(
|
||||
self.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
try:
|
||||
U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
|
||||
except RuntimeError as e:
|
||||
logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
|
||||
try:
|
||||
U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False)
|
||||
U = torch.from_numpy(U).to(self.device)
|
||||
S = torch.from_numpy(S).to(self.device)
|
||||
Vh = torch.from_numpy(Vh).to(self.device)
|
||||
except np.linalg.LinAlgError as e2:
|
||||
logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
|
||||
# Return zero eigenvalues/vectors as fallback
|
||||
return (
|
||||
torch.zeros(self.d_cdc, d, dtype=torch.float16),
|
||||
torch.zeros(self.d_cdc, dtype=torch.float16)
|
||||
)
|
||||
|
||||
# Eigenvalues of Γ_b
|
||||
eigenvalues_full = S ** 2
|
||||
|
||||
# Keep top d_cdc
|
||||
if len(eigenvalues_full) >= self.d_cdc:
|
||||
top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
|
||||
top_eigenvectors = Vh[top_idx, :] # (d_cdc, d)
|
||||
else:
|
||||
# Pad if k < d_cdc
|
||||
top_eigenvalues = eigenvalues_full
|
||||
top_eigenvectors = Vh
|
||||
if len(eigenvalues_full) < self.d_cdc:
|
||||
pad_size = self.d_cdc - len(eigenvalues_full)
|
||||
top_eigenvalues = torch.cat([
|
||||
top_eigenvalues,
|
||||
torch.zeros(pad_size, device=self.device)
|
||||
])
|
||||
top_eigenvectors = torch.cat([
|
||||
top_eigenvectors,
|
||||
torch.zeros(pad_size, d, device=self.device)
|
||||
])
|
||||
|
||||
# Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
|
||||
# Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
|
||||
# Then apply gamma: γc_i Γ̂(x^(i))
|
||||
#
|
||||
# Our implementation:
|
||||
# 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor
|
||||
# 2. Apply gamma hyperparameter - aligns with paper's γ global scaling
|
||||
# 3. Clamp for numerical stability
|
||||
#
|
||||
# Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents)
|
||||
# Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound
|
||||
|
||||
# Step 1: Normalize by the maximum eigenvalue to get relative scales
|
||||
# This is the paper's 1/λ_1^i normalization factor
|
||||
max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0
|
||||
|
||||
if max_eigenval > 1e-10:
|
||||
# Scale so max eigenvalue = 1.0, preserving relative ratios
|
||||
top_eigenvalues = top_eigenvalues / max_eigenval
|
||||
|
||||
# Step 2: Apply gamma and clamp to safe range
|
||||
# Gamma is the paper's tuneable hyperparameter (defaults to 1.0)
|
||||
# Clamping ensures numerical stability and prevents extreme values
|
||||
top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0)
|
||||
|
||||
# Convert to fp16 for storage - now safe since eigenvalues are ~0.01-1.0
|
||||
# fp16 range: 6e-5 to 65,504, our values are well within this
|
||||
eigenvectors_fp16 = top_eigenvectors.cpu().half()
|
||||
eigenvalues_fp16 = top_eigenvalues.cpu().half()
|
||||
|
||||
# Cleanup
|
||||
del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return eigenvectors_fp16, eigenvalues_fp16
|
||||
|
||||
def compute_for_batch(
|
||||
self,
|
||||
latents_np: np.ndarray,
|
||||
global_indices: List[int]
|
||||
) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Compute Γ_b for all points in a batch of same-size latents
|
||||
|
||||
Args:
|
||||
latents_np: (N, d) numpy array
|
||||
global_indices: List of global dataset indices for each latent
|
||||
|
||||
Returns:
|
||||
Dict mapping global_idx -> (eigenvectors, eigenvalues)
|
||||
"""
|
||||
N, d = latents_np.shape
|
||||
|
||||
# Validate inputs
|
||||
if len(global_indices) != N:
|
||||
raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")
|
||||
|
||||
print(f"Computing CDC for batch: {N} samples, dim={d}")
|
||||
|
||||
# Handle small sample cases - require minimum samples for meaningful k-NN
|
||||
MIN_SAMPLES_FOR_CDC = 5 # Need at least 5 samples for reasonable geometry estimation
|
||||
|
||||
if N < MIN_SAMPLES_FOR_CDC:
|
||||
print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
|
||||
results = {}
|
||||
for local_idx in range(N):
|
||||
global_idx = global_indices[local_idx]
|
||||
# Return zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x)
|
||||
eigvecs = np.zeros((self.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.d_cdc, dtype=np.float16)
|
||||
results[global_idx] = (eigvecs, eigvals)
|
||||
return results
|
||||
|
||||
# Step 1: Build k-NN graph
|
||||
print(" Building k-NN graph...")
|
||||
distances, indices = self.compute_knn_graph(latents_np)
|
||||
|
||||
# Step 2: Compute bandwidth
|
||||
# Use min to handle case where k_bw >= actual neighbors returned
|
||||
k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
|
||||
epsilon = distances[:, k_bw_actual]
|
||||
|
||||
# Step 3: Compute Γ_b for each point
|
||||
results = {}
|
||||
print(" Computing Γ_b for each point...")
|
||||
for local_idx in tqdm(range(N), desc=" Processing", leave=False):
|
||||
global_idx = global_indices[local_idx]
|
||||
eigvecs, eigvals = self.compute_gamma_b_single(
|
||||
local_idx, latents_np, distances, indices, epsilon
|
||||
)
|
||||
results[global_idx] = (eigvecs, eigvals)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class LatentBatcher:
|
||||
"""
|
||||
Collects variable-size latents and batches them by size
|
||||
"""
|
||||
|
||||
def __init__(self, size_tolerance: float = 0.0):
|
||||
"""
|
||||
Args:
|
||||
size_tolerance: If > 0, group latents within tolerance % of size
|
||||
If 0, only exact size matches are batched
|
||||
"""
|
||||
self.size_tolerance = size_tolerance
|
||||
self.samples: List[LatentSample] = []
|
||||
|
||||
def add_sample(self, sample: LatentSample):
|
||||
"""Add a single latent sample"""
|
||||
self.samples.append(sample)
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
latent: Union[np.ndarray, torch.Tensor],
|
||||
global_idx: int,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
metadata: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Add a latent vector with automatic shape tracking
|
||||
|
||||
Args:
|
||||
latent: Latent vector (any shape, will be flattened)
|
||||
global_idx: Global index in dataset
|
||||
shape: Original shape (if None, uses latent.shape)
|
||||
metadata: Optional metadata dict
|
||||
"""
|
||||
# Convert to numpy and flatten
|
||||
if isinstance(latent, torch.Tensor):
|
||||
latent_np = latent.cpu().numpy()
|
||||
else:
|
||||
latent_np = latent
|
||||
|
||||
original_shape = shape if shape is not None else latent_np.shape
|
||||
latent_flat = latent_np.flatten()
|
||||
|
||||
sample = LatentSample(
|
||||
latent=latent_flat,
|
||||
global_idx=global_idx,
|
||||
shape=original_shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
self.add_sample(sample)
|
||||
|
||||
def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
|
||||
"""
|
||||
Group samples by exact shape to avoid resizing distortion.
|
||||
|
||||
Each bucket contains only samples with identical latent dimensions.
|
||||
Buckets with fewer than k_neighbors samples will be skipped during CDC
|
||||
computation and fall back to standard Gaussian noise.
|
||||
|
||||
Returns:
|
||||
Dict mapping exact_shape -> list of samples with that shape
|
||||
"""
|
||||
batches = {}
|
||||
|
||||
for sample in self.samples:
|
||||
shape_key = sample.shape
|
||||
|
||||
# Group by exact shape only - no aspect ratio grouping or resizing
|
||||
if shape_key not in batches:
|
||||
batches[shape_key] = []
|
||||
|
||||
batches[shape_key].append(sample)
|
||||
|
||||
return batches
|
||||
|
||||
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
|
||||
"""
|
||||
Get aspect ratio category for grouping.
|
||||
Groups images by aspect ratio bins to ensure sufficient samples.
|
||||
|
||||
For shape (C, H, W), computes aspect ratio H/W and bins it.
|
||||
"""
|
||||
if len(shape) < 3:
|
||||
return "unknown"
|
||||
|
||||
# Extract spatial dimensions (H, W)
|
||||
h, w = shape[-2], shape[-1]
|
||||
aspect_ratio = h / w
|
||||
|
||||
# Define aspect ratio bins (±15% tolerance)
|
||||
# Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16)
|
||||
bins = [
|
||||
(0.5, 0.65, "9:16"), # Portrait tall
|
||||
(0.65, 0.85, "3:4"), # Portrait
|
||||
(0.85, 1.15, "1:1"), # Square
|
||||
(1.15, 1.50, "4:3"), # Landscape
|
||||
(1.50, 2.0, "16:9"), # Landscape wide
|
||||
(2.0, 3.0, "21:9"), # Ultra wide
|
||||
]
|
||||
|
||||
for min_ratio, max_ratio, label in bins:
|
||||
if min_ratio <= aspect_ratio < max_ratio:
|
||||
return label
|
||||
|
||||
# Fallback for extreme ratios
|
||||
if aspect_ratio < 0.5:
|
||||
return "ultra_tall"
|
||||
else:
|
||||
return "ultra_wide"
|
||||
|
||||
def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
|
||||
"""Check if two shapes are within tolerance"""
|
||||
if len(shape1) != len(shape2):
|
||||
return False
|
||||
|
||||
size1 = np.prod(shape1)
|
||||
size2 = np.prod(shape2)
|
||||
|
||||
ratio = abs(size1 - size2) / max(size1, size2)
|
||||
return ratio <= self.size_tolerance
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
|
||||
class CDCPreprocessor:
|
||||
"""
|
||||
High-level CDC preprocessing coordinator
|
||||
Handles variable-size latents by batching and delegating to CarreDuChampComputer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_neighbors: int = 256,
|
||||
k_bandwidth: int = 8,
|
||||
d_cdc: int = 8,
|
||||
gamma: float = 1.0,
|
||||
device: str = 'cuda',
|
||||
size_tolerance: float = 0.0
|
||||
):
|
||||
self.computer = CarreDuChampComputer(
|
||||
k_neighbors=k_neighbors,
|
||||
k_bandwidth=k_bandwidth,
|
||||
d_cdc=d_cdc,
|
||||
gamma=gamma,
|
||||
device=device
|
||||
)
|
||||
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
latent: Union[np.ndarray, torch.Tensor],
|
||||
global_idx: int,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
metadata: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Add a single latent to the preprocessing queue
|
||||
|
||||
Args:
|
||||
latent: Latent vector (will be flattened)
|
||||
global_idx: Global dataset index
|
||||
shape: Original shape (C, H, W)
|
||||
metadata: Optional metadata
|
||||
"""
|
||||
self.batcher.add_latent(latent, global_idx, shape, metadata)
|
||||
|
||||
def compute_all(self, save_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Compute Γ_b for all added latents and save to safetensors
|
||||
|
||||
Args:
|
||||
save_path: Path to save the results
|
||||
|
||||
Returns:
|
||||
Path to saved file
|
||||
"""
|
||||
save_path = Path(save_path)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get batches by exact size (no resizing)
|
||||
batches = self.batcher.get_batches()
|
||||
|
||||
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
|
||||
|
||||
# Count samples that will get CDC vs fallback
|
||||
k_neighbors = self.computer.k
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
samples_fallback = len(self.batcher) - samples_with_cdc
|
||||
|
||||
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
|
||||
|
||||
# Storage for results
|
||||
all_results = {}
|
||||
|
||||
# Process each bucket
|
||||
for shape, samples in batches.items():
|
||||
num_samples = len(samples)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Bucket: {shape} ({num_samples} samples)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if bucket has enough samples for k-NN
|
||||
if num_samples < k_neighbors:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
|
||||
# Collect latents (no resizing needed - all same shape)
|
||||
latents_list = []
|
||||
global_indices = []
|
||||
|
||||
for sample in samples:
|
||||
global_indices.append(sample.global_idx)
|
||||
latents_list.append(sample.latent) # Already flattened
|
||||
|
||||
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
|
||||
|
||||
# Compute CDC for this batch
|
||||
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
|
||||
|
||||
# No resizing needed - eigenvectors are already correct size
|
||||
print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")
|
||||
|
||||
# Merge into overall results
|
||||
all_results.update(batch_results)
|
||||
|
||||
# Save to safetensors
|
||||
print(f"\n{'='*60}")
|
||||
print("Saving results...")
|
||||
print(f"{'='*60}")
|
||||
|
||||
tensors_dict = {
|
||||
'metadata/num_samples': torch.tensor([len(all_results)]),
|
||||
'metadata/k_neighbors': torch.tensor([self.computer.k]),
|
||||
'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
|
||||
'metadata/gamma': torch.tensor([self.computer.gamma]),
|
||||
}
|
||||
|
||||
# Add shape information for each sample
|
||||
for sample in self.batcher.samples:
|
||||
idx = sample.global_idx
|
||||
tensors_dict[f'shapes/{idx}'] = torch.tensor(sample.shape)
|
||||
|
||||
# Add CDC results (convert numpy to torch tensors)
|
||||
for global_idx, (eigvecs, eigvals) in all_results.items():
|
||||
# Convert numpy arrays to torch tensors
|
||||
if isinstance(eigvecs, np.ndarray):
|
||||
eigvecs = torch.from_numpy(eigvecs)
|
||||
if isinstance(eigvals, np.ndarray):
|
||||
eigvals = torch.from_numpy(eigvals)
|
||||
|
||||
tensors_dict[f'eigenvectors/{global_idx}'] = eigvecs
|
||||
tensors_dict[f'eigenvalues/{global_idx}'] = eigvals
|
||||
|
||||
save_file(tensors_dict, save_path)
|
||||
|
||||
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
|
||||
print(f"\nSaved to {save_path}")
|
||||
print(f"File size: {file_size_gb:.2f} GB")
|
||||
|
||||
return save_path
|
||||
|
||||
|
||||
class GammaBDataset:
|
||||
"""
|
||||
Efficient loader for Γ_b matrices during training
|
||||
Handles variable-size latents
|
||||
"""
|
||||
|
||||
def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
|
||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||
self.gamma_b_path = Path(gamma_b_path)
|
||||
|
||||
# Load metadata
|
||||
print(f"Loading Γ_b from {gamma_b_path}...")
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
|
||||
self.num_samples = int(f.get_tensor('metadata/num_samples').item())
|
||||
self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())
|
||||
|
||||
print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
|
||||
|
||||
@torch.no_grad()
|
||||
def get_gamma_b_sqrt(
|
||||
self,
|
||||
indices: Union[List[int], np.ndarray, torch.Tensor],
|
||||
device: Optional[str] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get Γ_b^(1/2) components for a batch of indices
|
||||
|
||||
Args:
|
||||
indices: Sample indices
|
||||
device: Device to load to (defaults to self.device)
|
||||
|
||||
Returns:
|
||||
eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
|
||||
eigenvalues: (B, d_cdc)
|
||||
"""
|
||||
if device is None:
|
||||
device = self.device
|
||||
|
||||
# Convert indices to list
|
||||
if isinstance(indices, torch.Tensor):
|
||||
indices = indices.cpu().numpy().tolist()
|
||||
elif isinstance(indices, np.ndarray):
|
||||
indices = indices.tolist()
|
||||
|
||||
# Load from safetensors
|
||||
from safetensors import safe_open
|
||||
|
||||
eigenvectors_list = []
|
||||
eigenvalues_list = []
|
||||
|
||||
with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
eigvecs = f.get_tensor(f'eigenvectors/{idx}').float()
|
||||
eigvals = f.get_tensor(f'eigenvalues/{idx}').float()
|
||||
|
||||
eigenvectors_list.append(eigvecs)
|
||||
eigenvalues_list.append(eigvals)
|
||||
|
||||
# Stack - all should have same d_cdc and d within a batch (enforced by bucketing)
|
||||
# Check if all eigenvectors have the same dimension
|
||||
dims = [ev.shape[1] for ev in eigenvectors_list]
|
||||
if len(set(dims)) > 1:
|
||||
# Dimension mismatch! This shouldn't happen with proper bucketing
|
||||
# but can occur if batch contains mixed sizes
|
||||
raise RuntimeError(
|
||||
f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
|
||||
f"Batch indices: {indices}. "
|
||||
f"This means the training batch contains images of different sizes, "
|
||||
f"which violates CDC's requirement for uniform latent dimensions per batch. "
|
||||
f"Check that your dataloader buckets are configured correctly."
|
||||
)
|
||||
|
||||
eigenvectors = torch.stack(eigenvectors_list, dim=0)
|
||||
eigenvalues = torch.stack(eigenvalues_list, dim=0)
|
||||
|
||||
return eigenvectors, eigenvalues
|
||||
|
||||
def get_shape(self, idx: int) -> Tuple[int, ...]:
|
||||
"""Get the original shape for a sample"""
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
|
||||
shape_tensor = f.get_tensor(f'shapes/{idx}')
|
||||
return tuple(shape_tensor.numpy().tolist())
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_sigma_t_x(
|
||||
self,
|
||||
eigenvectors: torch.Tensor,
|
||||
eigenvalues: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
t: Union[float, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2)
|
||||
|
||||
Args:
|
||||
eigenvectors: (B, d_cdc, d)
|
||||
eigenvalues: (B, d_cdc)
|
||||
x: (B, d) or (B, C, H, W) - will be flattened if needed
|
||||
t: (B,) or scalar time
|
||||
|
||||
Returns:
|
||||
result: Same shape as input x
|
||||
"""
|
||||
# Store original shape to restore later
|
||||
orig_shape = x.shape
|
||||
|
||||
# Flatten x if it's 4D
|
||||
if x.dim() == 4:
|
||||
B, C, H, W = x.shape
|
||||
x = x.reshape(B, -1) # (B, C*H*W)
|
||||
|
||||
if not isinstance(t, torch.Tensor):
|
||||
t = torch.tensor(t, device=x.device, dtype=x.dtype)
|
||||
|
||||
if t.dim() == 0:
|
||||
t = t.expand(x.shape[0])
|
||||
|
||||
t = t.view(-1, 1)
|
||||
|
||||
# Early return for t=0 to avoid numerical errors
|
||||
if torch.allclose(t, torch.zeros_like(t), atol=1e-8):
|
||||
return x.reshape(orig_shape)
|
||||
|
||||
# Check if CDC is disabled (all eigenvalues are zero)
|
||||
# This happens for buckets with < k_neighbors samples
|
||||
if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8):
|
||||
# Fallback to standard Gaussian noise (no CDC correction)
|
||||
return x.reshape(orig_shape)
|
||||
|
||||
# Γ_b^(1/2) @ x using low-rank representation
|
||||
Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
|
||||
sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
|
||||
sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
|
||||
gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)
|
||||
|
||||
# Σ_t @ x
|
||||
result = (1 - t) * x + t * gamma_sqrt_x
|
||||
|
||||
# Restore original shape
|
||||
result = result.reshape(orig_shape)
|
||||
|
||||
return result
|
||||
@@ -2,10 +2,8 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
import toml
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator, PartialState
|
||||
@@ -183,7 +181,7 @@ def sample_image_inference(
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
logger.info("negative prompt is ignored because scale is 1.0")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
@@ -469,8 +467,16 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype,
|
||||
gamma_b_dataset=None, batch_indices=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps for training.
|
||||
|
||||
Args:
|
||||
gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise
|
||||
batch_indices: Optional batch indices for CDC-FM (required if gamma_b_dataset provided)
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
@@ -514,6 +520,44 @@ def get_noisy_model_input_and_timesteps(
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Apply CDC-FM geometry-aware noise transformation if enabled
|
||||
if gamma_b_dataset is not None and batch_indices is not None:
|
||||
# Normalize timesteps to [0, 1] for CDC-FM
|
||||
t_normalized = timesteps / num_timesteps
|
||||
|
||||
# Process each sample individually to handle potential dimension mismatches
|
||||
# (can happen with multi-subset training where bucketing differs between preprocessing and training)
|
||||
B, C, H, W = noise.shape
|
||||
noise_transformed = []
|
||||
|
||||
for i in range(B):
|
||||
idx = batch_indices[i].item() if isinstance(batch_indices[i], torch.Tensor) else batch_indices[i]
|
||||
|
||||
# Get cached shape for this sample
|
||||
cached_shape = gamma_b_dataset.get_shape(idx)
|
||||
current_shape = (C, H, W)
|
||||
|
||||
if cached_shape != current_shape:
|
||||
# Shape mismatch - sample was bucketed differently between preprocessing and training
|
||||
# Use standard Gaussian noise for this sample (no CDC)
|
||||
logger.warning(
|
||||
f"CDC shape mismatch for sample {idx}: "
|
||||
f"cached {cached_shape} vs current {current_shape}. "
|
||||
f"Using Gaussian noise (no CDC)."
|
||||
)
|
||||
noise_transformed.append(noise[i])
|
||||
else:
|
||||
# Shapes match - apply CDC transformation
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([idx], device=device)
|
||||
|
||||
noise_flat = noise[i].reshape(1, -1)
|
||||
t_single = t_normalized[i:i+1] if t_normalized.dim() > 0 else t_normalized
|
||||
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_single)
|
||||
noise_transformed.append(noise_cdc_flat.reshape(C, H, W))
|
||||
|
||||
noise = torch.stack(noise_transformed, dim=0)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
|
||||
@@ -1569,11 +1569,19 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs_list = []
|
||||
custom_attributes = []
|
||||
indices = [] # CDC-FM: track global dataset indices
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
|
||||
# CDC-FM: Get global index for this image
|
||||
# Create a sorted list of keys to ensure deterministic indexing
|
||||
if not hasattr(self, '_image_key_to_index'):
|
||||
self._image_key_to_index = {key: idx for idx, key in enumerate(sorted(self.image_data.keys()))}
|
||||
global_idx = self._image_key_to_index[image_key]
|
||||
indices.append(global_idx)
|
||||
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
|
||||
# in case of fine tuning, is_reg is always False
|
||||
@@ -1819,6 +1827,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
# CDC-FM: Add global indices to batch
|
||||
example["indices"] = torch.LongTensor(indices)
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -2690,6 +2701,127 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.new_cache_text_encoder_outputs(models, accelerator)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
def cache_cdc_gamma_b(
|
||||
self,
|
||||
cdc_output_path: str,
|
||||
k_neighbors: int = 256,
|
||||
k_bandwidth: int = 8,
|
||||
d_cdc: int = 8,
|
||||
gamma: float = 1.0,
|
||||
force_recache: bool = False,
|
||||
accelerator: Optional["Accelerator"] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
|
||||
Args:
|
||||
cdc_output_path: Path to save cdc_gamma_b.safetensors
|
||||
k_neighbors: k-NN neighbors
|
||||
k_bandwidth: Bandwidth estimation neighbors
|
||||
d_cdc: CDC subspace dimension
|
||||
gamma: CDC strength
|
||||
force_recache: Force recompute even if cache exists
|
||||
accelerator: For multi-GPU support
|
||||
|
||||
Returns:
|
||||
Path to cached CDC file
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
cdc_path = Path(cdc_output_path)
|
||||
|
||||
# Check if valid cache exists
|
||||
if cdc_path.exists() and not force_recache:
|
||||
if self._is_cdc_cache_valid(cdc_path, k_neighbors, d_cdc, gamma):
|
||||
logger.info(f"Valid CDC cache found at {cdc_path}, skipping preprocessing")
|
||||
return str(cdc_path)
|
||||
else:
|
||||
logger.info(f"CDC cache found but invalid, will recompute")
|
||||
|
||||
# Only main process computes CDC
|
||||
is_main = accelerator is None or accelerator.is_main_process
|
||||
if not is_main:
|
||||
if accelerator is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
return str(cdc_path)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Starting CDC-FM preprocessing")
|
||||
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize CDC preprocessor
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
from library.strategy_base import LatentsCachingStrategy
|
||||
|
||||
caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
# Collect all latents from all datasets
|
||||
for dataset_idx, dataset in enumerate(self.datasets):
|
||||
logger.info(f"Loading latents from dataset {dataset_idx}...")
|
||||
image_infos = list(dataset.image_data.values())
|
||||
|
||||
for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset {dataset_idx}")):
|
||||
# Load latent from disk or memory
|
||||
if info.latents is not None:
|
||||
latent = info.latents
|
||||
elif info.latents_npz is not None:
|
||||
# Load from disk
|
||||
latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso)
|
||||
if latent is None:
|
||||
logger.warning(f"Failed to load latent from {info.latents_npz}, skipping")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"No latent found for {info.absolute_path}, skipping")
|
||||
continue
|
||||
|
||||
# Add to preprocessor (with unique global index across all datasets)
|
||||
actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx
|
||||
preprocessor.add_latent(latent=latent, global_idx=actual_global_idx, shape=latent.shape, metadata={"image_key": info.image_key})
|
||||
|
||||
# Compute and save
|
||||
logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...")
|
||||
preprocessor.compute_all(save_path=cdc_path)
|
||||
|
||||
if accelerator is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return str(cdc_path)
|
||||
|
||||
def _is_cdc_cache_valid(self, cdc_path: "pathlib.Path", k_neighbors: int, d_cdc: int, gamma: float) -> bool:
|
||||
"""Check if CDC cache has matching hyperparameters"""
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(cdc_path), framework="pt", device="cpu") as f:
|
||||
cached_k = int(f.get_tensor("metadata/k_neighbors").item())
|
||||
cached_d = int(f.get_tensor("metadata/d_cdc").item())
|
||||
cached_gamma = float(f.get_tensor("metadata/gamma").item())
|
||||
cached_num = int(f.get_tensor("metadata/num_samples").item())
|
||||
|
||||
expected_num = sum(len(d.image_data) for d in self.datasets)
|
||||
|
||||
valid = cached_k == k_neighbors and cached_d == d_cdc and abs(cached_gamma - gamma) < 1e-6 and cached_num == expected_num
|
||||
|
||||
if not valid:
|
||||
logger.info(
|
||||
f"Cache mismatch: k={cached_k} (expected {k_neighbors}), "
|
||||
f"d_cdc={cached_d} (expected {d_cdc}), "
|
||||
f"gamma={cached_gamma} (expected {gamma}), "
|
||||
f"num={cached_num} (expected {expected_num})"
|
||||
)
|
||||
|
||||
return valid
|
||||
except Exception as e:
|
||||
logger.warning(f"Error validating CDC cache: {e}")
|
||||
return False
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_caching_mode(caching_mode)
|
||||
|
||||
242
tests/library/test_cdc_eigenvalue_scaling.py
Normal file
242
tests/library/test_cdc_eigenvalue_scaling.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Tests to verify CDC eigenvalue scaling is correct.
|
||||
|
||||
These tests ensure eigenvalues are properly scaled to prevent training loss explosion.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
|
||||
class TestEigenvalueScaling:
|
||||
"""Test that eigenvalues are properly scaled to reasonable ranges"""
|
||||
|
||||
def test_eigenvalues_in_correct_range(self, tmp_path):
|
||||
"""Verify eigenvalues are scaled to ~0.01-1.0 range, not millions"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Add deterministic latents with structured patterns
|
||||
for i in range(10):
|
||||
# Create gradient pattern: values from 0 to 2.0 across spatial dims
|
||||
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 32.0 # Range [0, 2.0]
|
||||
# Add per-sample variation
|
||||
latent = latent + i * 0.1
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify eigenvalues are in correct range
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
|
||||
# Filter out zero eigenvalues (from padding when k < d_cdc)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# Critical assertions for eigenvalue scale
|
||||
assert all_eigvals.max() < 10.0, f"Max eigenvalue {all_eigvals.max():.2e} is too large (should be <10)"
|
||||
assert len(non_zero_eigvals) > 0, "Should have some non-zero eigenvalues"
|
||||
assert np.mean(non_zero_eigvals) < 2.0, f"Mean eigenvalue {np.mean(non_zero_eigvals):.2e} is too large"
|
||||
|
||||
# Check sqrt (used in noise) is reasonable
|
||||
sqrt_max = np.sqrt(all_eigvals.max())
|
||||
assert sqrt_max < 5.0, f"sqrt(max eigenvalue) = {sqrt_max:.2f} will cause noise explosion"
|
||||
|
||||
print(f"\n✓ Eigenvalue range: [{all_eigvals.min():.4f}, {all_eigvals.max():.4f}]")
|
||||
print(f"✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
|
||||
print(f"✓ Mean (non-zero): {np.mean(non_zero_eigvals):.4f}")
|
||||
print(f"✓ sqrt(max): {sqrt_max:.4f}")
|
||||
|
||||
def test_eigenvalues_not_all_zero(self, tmp_path):
|
||||
"""Ensure eigenvalues are not all zero (indicating computation failure)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
# Create deterministic pattern
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h * 4 + w) / 32.0 + i * 0.2
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
all_eigvals = []
|
||||
for i in range(10):
|
||||
eigvals = f.get_tensor(f"eigenvalues/{i}").numpy()
|
||||
all_eigvals.extend(eigvals)
|
||||
|
||||
all_eigvals = np.array(all_eigvals)
|
||||
non_zero_eigvals = all_eigvals[all_eigvals > 1e-6]
|
||||
|
||||
# With clamping, eigenvalues will be in range [1e-3, gamma*1.0]
|
||||
# Check that we have some non-zero eigenvalues
|
||||
assert len(non_zero_eigvals) > 0, "All eigenvalues are zero - computation failed"
|
||||
|
||||
# Check they're in the expected clamped range
|
||||
assert np.all(non_zero_eigvals >= 1e-3), f"Some eigenvalues below clamp min: {np.min(non_zero_eigvals)}"
|
||||
assert np.all(non_zero_eigvals <= 1.0), f"Some eigenvalues above clamp max: {np.max(non_zero_eigvals)}"
|
||||
|
||||
print(f"\n✓ Non-zero eigenvalues: {len(non_zero_eigvals)}/{len(all_eigvals)}")
|
||||
print(f"✓ Range: [{np.min(non_zero_eigvals):.4f}, {np.max(non_zero_eigvals):.4f}]")
|
||||
print(f"✓ Mean: {np.mean(non_zero_eigvals):.4f}")
|
||||
|
||||
def test_fp16_storage_no_overflow(self, tmp_path):
|
||||
"""Verify fp16 storage doesn't overflow (max fp16 = 65,504)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
# Create deterministic pattern with higher magnitude
|
||||
latent = torch.zeros(16, 8, 8, dtype=torch.float32)
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent[:, h, w] = (h * 8 + w) / 16.0 # Range [0, 4.0]
|
||||
latent = latent + i * 0.3
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check dtype is fp16
|
||||
eigvecs = f.get_tensor("eigenvectors/0")
|
||||
eigvals = f.get_tensor("eigenvalues/0")
|
||||
|
||||
assert eigvecs.dtype == torch.float16, f"Expected fp16, got {eigvecs.dtype}"
|
||||
assert eigvals.dtype == torch.float16, f"Expected fp16, got {eigvals.dtype}"
|
||||
|
||||
# Check no values near fp16 max (would indicate overflow)
|
||||
FP16_MAX = 65504
|
||||
max_eigval = eigvals.max().item()
|
||||
|
||||
assert max_eigval < 100, (
|
||||
f"Eigenvalue {max_eigval:.2e} is suspiciously large for fp16 storage. "
|
||||
f"May indicate overflow (fp16 max = {FP16_MAX})"
|
||||
)
|
||||
|
||||
print(f"\n✓ Storage dtype: {eigvals.dtype}")
|
||||
print(f"✓ Max eigenvalue: {max_eigval:.4f} (safe for fp16)")
|
||||
|
||||
def test_latent_magnitude_preserved(self, tmp_path):
|
||||
"""Verify latent magnitude is preserved (no unwanted normalization)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Store original latents with deterministic patterns
|
||||
original_latents = []
|
||||
for i in range(10):
|
||||
# Create structured pattern with known magnitude
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) + i * 0.5
|
||||
original_latents.append(latent.clone())
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
# Compute original latent statistics
|
||||
orig_std = torch.stack(original_latents).std().item()
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# The stored latents should preserve original magnitude
|
||||
stored_latents_std = np.std([s.latent for s in preprocessor.batcher.samples])
|
||||
|
||||
# Should be similar to original (within 20% due to potential batching effects)
|
||||
assert 0.8 * orig_std < stored_latents_std < 1.2 * orig_std, (
|
||||
f"Stored latent std {stored_latents_std:.2f} differs too much from "
|
||||
f"original {orig_std:.2f}. Latent magnitude was not preserved."
|
||||
)
|
||||
|
||||
print(f"\n✓ Original latent std: {orig_std:.2f}")
|
||||
print(f"✓ Stored latent std: {stored_latents_std:.2f}")
|
||||
|
||||
|
||||
class TestTrainingLossScale:
|
||||
"""Test that eigenvalues produce reasonable loss magnitudes"""
|
||||
|
||||
def test_noise_magnitude_reasonable(self, tmp_path):
|
||||
"""Verify CDC noise has reasonable magnitude for training"""
|
||||
from library.cdc_fm import GammaBDataset
|
||||
|
||||
# Create CDC cache with deterministic data
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
# Create deterministic pattern
|
||||
latent = torch.zeros(16, 4, 4, dtype=torch.float32)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent[c, h, w] = (c + h + w) / 20.0 + i * 0.1
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Load and compute noise
|
||||
gamma_b = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
# Simulate training scenario with deterministic data
|
||||
batch_size = 3
|
||||
latents = torch.zeros(batch_size, 16, 4, 4)
|
||||
for b in range(batch_size):
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latents[b, c, h, w] = (b + c + h + w) / 24.0
|
||||
t = torch.tensor([0.5, 0.7, 0.9]) # Different timesteps
|
||||
indices = [0, 5, 9]
|
||||
|
||||
eigvecs, eigvals = gamma_b.get_gamma_b_sqrt(indices)
|
||||
noise = gamma_b.compute_sigma_t_x(eigvecs, eigvals, latents, t)
|
||||
|
||||
# Check noise magnitude
|
||||
noise_std = noise.std().item()
|
||||
latent_std = latents.std().item()
|
||||
|
||||
# Noise should be similar magnitude to input latents (within 10x)
|
||||
ratio = noise_std / latent_std
|
||||
assert 0.1 < ratio < 10.0, (
|
||||
f"Noise std ({noise_std:.3f}) vs latent std ({latent_std:.3f}) "
|
||||
f"ratio {ratio:.2f} is too extreme. Will cause training instability."
|
||||
)
|
||||
|
||||
# Simulated MSE loss should be reasonable
|
||||
simulated_loss = torch.mean((noise - latents) ** 2).item()
|
||||
assert simulated_loss < 100.0, (
|
||||
f"Simulated MSE loss {simulated_loss:.2f} is too high. "
|
||||
f"Should be O(0.1-1.0) for stable training."
|
||||
)
|
||||
|
||||
print(f"\n✓ Noise/latent ratio: {ratio:.2f}")
|
||||
print(f"✓ Simulated MSE loss: {simulated_loss:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
164
tests/library/test_cdc_interpolation_comparison.py
Normal file
164
tests/library/test_cdc_interpolation_comparison.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Test comparing interpolation vs pad/truncate for CDC preprocessing.
|
||||
|
||||
This test quantifies the difference between the two approaches.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class TestInterpolationComparison:
|
||||
"""Compare interpolation vs pad/truncate"""
|
||||
|
||||
def test_intermediate_representation_quality(self):
|
||||
"""Compare intermediate representation quality for CDC computation"""
|
||||
# Create test latents with different sizes - deterministic
|
||||
latent_small = torch.zeros(16, 4, 4)
|
||||
for c in range(16):
|
||||
for h in range(4):
|
||||
for w in range(4):
|
||||
latent_small[c, h, w] = (c * 0.1 + h * 0.2 + w * 0.3) / 3.0
|
||||
|
||||
latent_large = torch.zeros(16, 8, 8)
|
||||
for c in range(16):
|
||||
for h in range(8):
|
||||
for w in range(8):
|
||||
latent_large[c, h, w] = (c * 0.1 + h * 0.15 + w * 0.15) / 3.0
|
||||
|
||||
target_h, target_w = 6, 6 # Median size
|
||||
|
||||
# Method 1: Interpolation
|
||||
def interpolate_method(latent, target_h, target_w):
|
||||
latent_input = latent.unsqueeze(0) # (1, C, H, W)
|
||||
latent_resized = F.interpolate(
|
||||
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
|
||||
)
|
||||
# Resize back
|
||||
C, H, W = latent.shape
|
||||
latent_reconstructed = F.interpolate(
|
||||
latent_resized, size=(H, W), mode='bilinear', align_corners=False
|
||||
)
|
||||
error = torch.mean(torch.abs(latent_reconstructed - latent_input)).item()
|
||||
relative_error = error / (torch.mean(torch.abs(latent_input)).item() + 1e-8)
|
||||
return relative_error
|
||||
|
||||
# Method 2: Pad/Truncate
|
||||
def pad_truncate_method(latent, target_h, target_w):
|
||||
C, H, W = latent.shape
|
||||
latent_flat = latent.reshape(-1)
|
||||
target_dim = C * target_h * target_w
|
||||
current_dim = C * H * W
|
||||
|
||||
if current_dim == target_dim:
|
||||
latent_resized_flat = latent_flat
|
||||
elif current_dim > target_dim:
|
||||
# Truncate
|
||||
latent_resized_flat = latent_flat[:target_dim]
|
||||
else:
|
||||
# Pad
|
||||
latent_resized_flat = torch.zeros(target_dim)
|
||||
latent_resized_flat[:current_dim] = latent_flat
|
||||
|
||||
# Resize back
|
||||
if current_dim == target_dim:
|
||||
latent_reconstructed_flat = latent_resized_flat
|
||||
elif current_dim > target_dim:
|
||||
# Pad back
|
||||
latent_reconstructed_flat = torch.zeros(current_dim)
|
||||
latent_reconstructed_flat[:target_dim] = latent_resized_flat
|
||||
else:
|
||||
# Truncate back
|
||||
latent_reconstructed_flat = latent_resized_flat[:current_dim]
|
||||
|
||||
latent_reconstructed = latent_reconstructed_flat.reshape(C, H, W)
|
||||
error = torch.mean(torch.abs(latent_reconstructed - latent)).item()
|
||||
relative_error = error / (torch.mean(torch.abs(latent)).item() + 1e-8)
|
||||
return relative_error
|
||||
|
||||
# Compare for small latent (needs padding)
|
||||
interp_error_small = interpolate_method(latent_small, target_h, target_w)
|
||||
pad_error_small = pad_truncate_method(latent_small, target_h, target_w)
|
||||
|
||||
# Compare for large latent (needs truncation)
|
||||
interp_error_large = interpolate_method(latent_large, target_h, target_w)
|
||||
truncate_error_large = pad_truncate_method(latent_large, target_h, target_w)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Reconstruction Error Comparison")
|
||||
print("=" * 60)
|
||||
print(f"\nSmall latent (16x4x4 -> 16x6x6 -> 16x4x4):")
|
||||
print(f" Interpolation error: {interp_error_small:.6f}")
|
||||
print(f" Pad/truncate error: {pad_error_small:.6f}")
|
||||
if pad_error_small > 0:
|
||||
print(f" Improvement: {(pad_error_small - interp_error_small) / pad_error_small * 100:.2f}%")
|
||||
else:
|
||||
print(f" Note: Pad/truncate has 0 reconstruction error (perfect recovery)")
|
||||
print(f" BUT the intermediate representation is corrupted with zeros!")
|
||||
|
||||
print(f"\nLarge latent (16x8x8 -> 16x6x6 -> 16x8x8):")
|
||||
print(f" Interpolation error: {interp_error_large:.6f}")
|
||||
print(f" Pad/truncate error: {truncate_error_large:.6f}")
|
||||
if truncate_error_large > 0:
|
||||
print(f" Improvement: {(truncate_error_large - interp_error_large) / truncate_error_large * 100:.2f}%")
|
||||
|
||||
# The key insight: Reconstruction error is NOT what matters for CDC!
|
||||
# What matters is the INTERMEDIATE representation quality used for geometry estimation.
|
||||
# Pad/truncate may have good reconstruction, but the intermediate is corrupted.
|
||||
|
||||
print("\nKey insight: For CDC, intermediate representation quality matters,")
|
||||
print("not reconstruction error. Interpolation preserves spatial structure.")
|
||||
|
||||
# Verify interpolation errors are reasonable
|
||||
assert interp_error_small < 1.0, "Interpolation should have reasonable error"
|
||||
assert interp_error_large < 1.0, "Interpolation should have reasonable error"
|
||||
|
||||
def test_spatial_structure_preservation(self):
|
||||
"""Test that interpolation preserves spatial structure better than pad/truncate"""
|
||||
# Create a latent with clear spatial pattern (gradient)
|
||||
C, H, W = 16, 4, 4
|
||||
latent = torch.zeros(C, H, W)
|
||||
for i in range(H):
|
||||
for j in range(W):
|
||||
latent[:, i, j] = i * W + j # Gradient pattern
|
||||
|
||||
target_h, target_w = 6, 6
|
||||
|
||||
# Interpolation
|
||||
latent_input = latent.unsqueeze(0)
|
||||
latent_interp = F.interpolate(
|
||||
latent_input, size=(target_h, target_w), mode='bilinear', align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Pad/truncate
|
||||
latent_flat = latent.reshape(-1)
|
||||
target_dim = C * target_h * target_w
|
||||
latent_padded = torch.zeros(target_dim)
|
||||
latent_padded[:len(latent_flat)] = latent_flat
|
||||
latent_pad = latent_padded.reshape(C, target_h, target_w)
|
||||
|
||||
# Check gradient preservation
|
||||
# For interpolation, adjacent pixels should have smooth gradients
|
||||
grad_x_interp = torch.abs(latent_interp[:, :, 1:] - latent_interp[:, :, :-1]).mean()
|
||||
grad_y_interp = torch.abs(latent_interp[:, 1:, :] - latent_interp[:, :-1, :]).mean()
|
||||
|
||||
# For padding, there will be abrupt changes (gradient to zero)
|
||||
grad_x_pad = torch.abs(latent_pad[:, :, 1:] - latent_pad[:, :, :-1]).mean()
|
||||
grad_y_pad = torch.abs(latent_pad[:, 1:, :] - latent_pad[:, :-1, :]).mean()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Spatial Structure Preservation")
|
||||
print("=" * 60)
|
||||
print(f"\nGradient smoothness (lower is smoother):")
|
||||
print(f" Interpolation - X gradient: {grad_x_interp:.4f}, Y gradient: {grad_y_interp:.4f}")
|
||||
print(f" Pad/truncate - X gradient: {grad_x_pad:.4f}, Y gradient: {grad_y_pad:.4f}")
|
||||
|
||||
# Padding introduces larger gradients due to abrupt zeros
|
||||
assert grad_x_pad > grad_x_interp, "Padding should introduce larger gradients"
|
||||
assert grad_y_pad > grad_y_interp, "Padding should introduce larger gradients"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
232
tests/library/test_cdc_standalone.py
Normal file
232
tests/library/test_cdc_standalone.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Standalone tests for CDC-FM integration.
|
||||
|
||||
These tests focus on CDC-FM specific functionality without importing
|
||||
the full training infrastructure that has problematic dependencies.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestCDCPreprocessor:
|
||||
"""Test CDC preprocessing functionality"""
|
||||
|
||||
def test_cdc_preprocessor_basic_workflow(self, tmp_path):
|
||||
"""Test basic CDC preprocessing with small dataset"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Add 10 small latents
|
||||
for i in range(10):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify file was created
|
||||
assert Path(result_path).exists()
|
||||
|
||||
# Verify structure
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
assert f.get_tensor("metadata/num_samples").item() == 10
|
||||
assert f.get_tensor("metadata/k_neighbors").item() == 5
|
||||
assert f.get_tensor("metadata/d_cdc").item() == 4
|
||||
|
||||
# Check first sample
|
||||
eigvecs = f.get_tensor("eigenvectors/0")
|
||||
eigvals = f.get_tensor("eigenvalues/0")
|
||||
|
||||
assert eigvecs.shape[0] == 4 # d_cdc
|
||||
assert eigvals.shape[0] == 4 # d_cdc
|
||||
|
||||
def test_cdc_preprocessor_different_shapes(self, tmp_path):
|
||||
"""Test CDC preprocessing with variable-size latents (bucketing)"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
# Add 5 latents of shape (16, 4, 4)
|
||||
for i in range(5):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
# Add 5 latents of different shape (16, 8, 8)
|
||||
for i in range(5, 10):
|
||||
latent = torch.randn(16, 8, 8, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
# Compute and save
|
||||
output_path = tmp_path / "test_gamma_b_multi.safetensors"
|
||||
result_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Verify both shape groups were processed
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(result_path), framework="pt", device="cpu") as f:
|
||||
# Check shapes are stored
|
||||
shape_0 = f.get_tensor("shapes/0")
|
||||
shape_5 = f.get_tensor("shapes/5")
|
||||
|
||||
assert tuple(shape_0.tolist()) == (16, 4, 4)
|
||||
assert tuple(shape_5.tolist()) == (16, 8, 8)
|
||||
|
||||
|
||||
class TestGammaBDataset:
|
||||
"""Test GammaBDataset loading and retrieval"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_cdc_cache(self, tmp_path):
|
||||
"""Create a sample CDC cache file for testing"""
|
||||
cache_path = tmp_path / "test_gamma_b.safetensors"
|
||||
|
||||
# Create mock Γ_b data for 5 samples
|
||||
tensors = {
|
||||
"metadata/num_samples": torch.tensor([5]),
|
||||
"metadata/k_neighbors": torch.tensor([10]),
|
||||
"metadata/d_cdc": torch.tensor([4]),
|
||||
"metadata/gamma": torch.tensor([1.0]),
|
||||
}
|
||||
|
||||
# Add shape and CDC data for each sample
|
||||
for i in range(5):
|
||||
tensors[f"shapes/{i}"] = torch.tensor([16, 8, 8]) # C, H, W
|
||||
tensors[f"eigenvectors/{i}"] = torch.randn(4, 1024, dtype=torch.float32) # d_cdc x d
|
||||
tensors[f"eigenvalues/{i}"] = torch.rand(4, dtype=torch.float32) + 0.1 # positive
|
||||
|
||||
save_file(tensors, str(cache_path))
|
||||
return cache_path
|
||||
|
||||
def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache):
|
||||
"""Test that GammaBDataset loads metadata correctly"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
|
||||
assert gamma_b_dataset.num_samples == 5
|
||||
assert gamma_b_dataset.d_cdc == 4
|
||||
|
||||
def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache):
|
||||
"""Test retrieving Γ_b^(1/2) components"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
|
||||
# Get Γ_b for indices [0, 2, 4]
|
||||
indices = [0, 2, 4]
|
||||
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(indices, device="cpu")
|
||||
|
||||
# Check shapes
|
||||
assert eigenvectors.shape == (3, 4, 1024) # (batch, d_cdc, d)
|
||||
assert eigenvalues.shape == (3, 4) # (batch, d_cdc)
|
||||
|
||||
# Check values are positive
|
||||
assert torch.all(eigenvalues > 0)
|
||||
|
||||
def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache):
|
||||
"""Test compute_sigma_t_x returns x unchanged at t=0"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
|
||||
# Create test latents (batch of 3, matching d=1024 flattened)
|
||||
x = torch.randn(3, 1024) # B, d (flattened)
|
||||
t = torch.zeros(3) # t = 0 for all samples
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 1, 2], device="cpu")
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
# At t=0, should return x unchanged
|
||||
assert torch.allclose(sigma_t_x, x, atol=1e-6)
|
||||
|
||||
def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache):
|
||||
"""Test compute_sigma_t_x returns correct shape"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
|
||||
x = torch.randn(2, 1024) # B, d (flattened)
|
||||
t = torch.tensor([0.3, 0.7])
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([1, 3], device="cpu")
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
# Should return same shape as input
|
||||
assert sigma_t_x.shape == x.shape
|
||||
|
||||
def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache):
|
||||
"""Test compute_sigma_t_x produces finite values"""
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=sample_cdc_cache, device="cpu")
|
||||
|
||||
x = torch.randn(3, 1024) # B, d (flattened)
|
||||
t = torch.rand(3) # Random timesteps in [0, 1]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([0, 2, 4], device="cpu")
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
# Should not contain NaNs or Infs
|
||||
assert not torch.isnan(sigma_t_x).any()
|
||||
assert torch.isfinite(sigma_t_x).all()
|
||||
|
||||
|
||||
class TestCDCEndToEnd:
|
||||
"""End-to-end CDC workflow tests"""
|
||||
|
||||
def test_full_preprocessing_and_usage_workflow(self, tmp_path):
|
||||
"""Test complete workflow: preprocess -> save -> load -> use"""
|
||||
# Step 1: Preprocess latents
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
num_samples = 10
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(16, 4, 4, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=latent.shape)
|
||||
|
||||
output_path = tmp_path / "cdc_gamma_b.safetensors"
|
||||
cdc_path = preprocessor.compute_all(save_path=output_path)
|
||||
|
||||
# Step 2: Load with GammaBDataset
|
||||
gamma_b_dataset = GammaBDataset(gamma_b_path=cdc_path, device="cpu")
|
||||
|
||||
assert gamma_b_dataset.num_samples == num_samples
|
||||
|
||||
# Step 3: Use in mock training scenario
|
||||
batch_size = 3
|
||||
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256)
|
||||
batch_t = torch.rand(batch_size)
|
||||
batch_indices = [0, 5, 9]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(batch_indices, device="cpu")
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
# Verify output is reasonable
|
||||
assert sigma_t_x.shape == batch_latents_flat.shape
|
||||
assert not torch.isnan(sigma_t_x).any()
|
||||
assert torch.isfinite(sigma_t_x).all()
|
||||
|
||||
# Verify that noise changes with different timesteps
|
||||
sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size))
|
||||
sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size))
|
||||
|
||||
# At t=0, should be close to x; at t=1, should be different
|
||||
assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6)
|
||||
assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -622,6 +622,23 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# CDC-FM preprocessing
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm:
|
||||
logger.info("CDC-FM enabled, preprocessing Γ_b matrices...")
|
||||
cdc_output_path = os.path.join(args.output_dir, "cdc_gamma_b.safetensors")
|
||||
|
||||
self.cdc_cache_path = train_dataset_group.cache_cdc_gamma_b(
|
||||
cdc_output_path=cdc_output_path,
|
||||
k_neighbors=args.cdc_k_neighbors,
|
||||
k_bandwidth=args.cdc_k_bandwidth,
|
||||
d_cdc=args.cdc_d_cdc,
|
||||
gamma=args.cdc_gamma,
|
||||
force_recache=args.force_recache_cdc,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
else:
|
||||
self.cdc_cache_path = None
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
||||
@@ -634,7 +651,7 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
|
||||
|
||||
if unet is None:
|
||||
if unet is none:
|
||||
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
|
||||
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
|
||||
|
||||
@@ -643,10 +660,10 @@ class NetworkTrainer:
|
||||
accelerator.print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not None:
|
||||
if args.base_weights is not none:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
if args.base_weights_multiplier is none or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
@@ -660,6 +677,17 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# Load CDC-FM Γ_b dataset if enabled
|
||||
if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_cache_path is not None:
|
||||
from library.cdc_fm import GammaBDataset
|
||||
|
||||
logger.info(f"Loading CDC Γ_b dataset from {self.cdc_cache_path}")
|
||||
self.gamma_b_dataset = GammaBDataset(
|
||||
gamma_b_path=self.cdc_cache_path, device="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
else:
|
||||
self.gamma_b_dataset = None
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
|
||||
Reference in New Issue
Block a user