mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size
- Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16)
- Fixed mode (default): Skip buckets with < k_neighbors samples
- Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size
- Update CDCPreprocessor to support adaptive k per bucket
- Add metadata tracking for adaptive_k and min_bucket_size
- Add comprehensive pytest tests for adaptive k behavior

This allows CDC-FM to work effectively with multi-resolution bucketing where bucket sizes may vary widely. Users can choose between strict paper methodology (fixed k) or pragmatic approach (adaptive k).
774 lines
29 KiB
Python
774 lines
29 KiB
Python
import logging
|
||
import torch
|
||
import numpy as np
|
||
import faiss # type: ignore
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
from safetensors.torch import save_file
|
||
from typing import List, Dict, Optional, Union, Tuple
|
||
from dataclasses import dataclass
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class LatentSample:
|
||
"""
|
||
Container for a single latent with metadata
|
||
"""
|
||
latent: np.ndarray # (d,) flattened latent vector
|
||
global_idx: int # Global index in dataset
|
||
shape: Tuple[int, ...] # Original shape before flattening (C, H, W)
|
||
metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.)
|
||
|
||
|
||
class CarreDuChampComputer:
    """
    Core CDC-FM computation - agnostic to data source.

    Just handles the math for a batch of same-size latents: builds a k-NN
    graph with FAISS, derives a per-point bandwidth, and computes a low-rank
    (d_cdc) eigendecomposition of the local Carre-du-Champ operator Γ_b
    for every point.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda'
    ):
        """
        Args:
            k_neighbors: Number of nearest neighbors for local geometry estimation
            k_bandwidth: Rank of the neighbor whose distance sets the bandwidth ε
            d_cdc: Number of eigenpairs kept per point (low-rank approximation)
            gamma: Global eigenvalue scaling hyperparameter (paper's γ)
            device: Preferred device; silently falls back to CPU if CUDA is absent
        """
        self.k = k_neighbors
        self.k_bw = k_bandwidth
        self.d_cdc = d_cdc
        self.gamma = gamma
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build k-NN graph using FAISS.

        Args:
            latents_np: (N, d) numpy array of same-dimensional latents

        Returns:
            distances: (N, k_actual+1) distances (k_actual may be less than k if N is small)
            indices: (N, k_actual+1) neighbor indices (column 0 is each point itself)

        NOTE(review): faiss.IndexFlatL2 returns *squared* L2 distances, and
        compute_gamma_b_single squares these values again when forming the
        Gaussian kernel — confirm the double-squaring is intentional.
        """
        N, d = latents_np.shape

        # Clamp k to available neighbors (can't have more neighbors than samples)
        k_actual = min(self.k, N - 1)

        # FAISS requires float32 input
        if latents_np.dtype != np.float32:
            latents_np = latents_np.astype(np.float32)

        # Build FAISS index
        index = faiss.IndexFlatL2(d)

        # Move the index to GPU 0 when available for a large search speedup
        if torch.cuda.is_available():
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)

        index.add(latents_np)  # type: ignore
        # +1 because the nearest "neighbor" of each query point is itself
        distances, indices = index.search(latents_np, k_actual + 1)  # type: ignore

        return distances, indices

    @torch.no_grad()
    def compute_gamma_b_single(
        self,
        point_idx: int,
        latents_np: np.ndarray,
        distances: np.ndarray,
        indices: np.ndarray,
        epsilon: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Γ_b for a single point.

        Args:
            point_idx: Index of point to process
            latents_np: (N, d) all latents in this batch
            distances: (N, k+1) precomputed distances
            indices: (N, k+1) precomputed neighbor indices
            epsilon: (N,) bandwidth per point

        Returns:
            eigenvectors: (d_cdc, d) as half precision tensor (on CPU)
            eigenvalues: (d_cdc,) as half precision tensor (on CPU)
        """
        d = latents_np.shape[1]

        # Get neighbors (exclude self - column 0 is the query point itself)
        neighbor_idx = indices[point_idx, 1:]  # (k,)
        neighbor_points = latents_np[neighbor_idx]  # (k, d)

        # Clamp distances to prevent overflow (max realistic L2 distance)
        MAX_DISTANCE = 1e10
        neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
        neighbor_dists_sq = neighbor_dists ** 2  # (k,)

        # Compute Gaussian kernel weights with numerical guards
        eps_i = max(epsilon[point_idx], 1e-10)  # Prevent division by zero
        eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)

        # Compute denominator with guard against overflow
        denom = eps_i * eps_neighbors
        denom = np.maximum(denom, 1e-20)  # Additional guard

        # Compute weights with safe exponential
        exp_arg = -neighbor_dists_sq / denom
        exp_arg = np.clip(exp_arg, -50, 0)  # Prevent exp overflow/underflow
        weights = np.exp(exp_arg)

        # Normalize weights, handle edge case of all zeros
        weight_sum = weights.sum()
        if weight_sum < 1e-20 or not np.isfinite(weight_sum):
            # Fallback to uniform weights
            weights = np.ones_like(weights) / len(weights)
        else:
            weights = weights / weight_sum

        # Compute local (kernel-weighted) mean
        m_star = np.sum(weights[:, None] * neighbor_points, axis=0)

        # Center and weight for SVD: singular values of sqrt(w)*(x - m*) give
        # the eigenvalues of the weighted local covariance via S**2
        centered = neighbor_points - m_star
        weighted_centered = np.sqrt(weights)[:, None] * centered  # (k, d)

        # Validate input is finite before SVD
        if not np.all(np.isfinite(weighted_centered)):
            logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
            # Fallback: use uniform weights and simple centering
            weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
            m_star = np.mean(neighbor_points, axis=0)
            centered = neighbor_points - m_star
            weighted_centered = np.sqrt(weights_uniform)[:, None] * centered

        # Move to GPU for SVD (100x speedup!)
        weighted_centered_torch = torch.from_numpy(weighted_centered).to(
            self.device, dtype=torch.float32
        )

        try:
            U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
        except RuntimeError as e:
            # GPU SVD can fail on ill-conditioned inputs; retry on CPU
            logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
            try:
                U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False)
                U = torch.from_numpy(U).to(self.device)
                S = torch.from_numpy(S).to(self.device)
                Vh = torch.from_numpy(Vh).to(self.device)
            except np.linalg.LinAlgError as e2:
                logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
                # Return zero eigenvalues/vectors as fallback
                return (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16)
                )

        # Eigenvalues of Γ_b
        eigenvalues_full = S ** 2

        # Keep top d_cdc eigenpairs
        if len(eigenvalues_full) >= self.d_cdc:
            top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
            top_eigenvectors = Vh[top_idx, :]  # (d_cdc, d)
        else:
            # Pad with zeros if fewer singular values than d_cdc (k < d_cdc)
            top_eigenvalues = eigenvalues_full
            top_eigenvectors = Vh
            if len(eigenvalues_full) < self.d_cdc:
                pad_size = self.d_cdc - len(eigenvalues_full)
                top_eigenvalues = torch.cat([
                    top_eigenvalues,
                    torch.zeros(pad_size, device=self.device)
                ])
                top_eigenvectors = torch.cat([
                    top_eigenvectors,
                    torch.zeros(pad_size, d, device=self.device)
                ])

        # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
        # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
        # Then apply gamma: γc_i Γ̂(x^(i))
        #
        # Our implementation:
        # 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor
        # 2. Apply gamma hyperparameter - aligns with paper's γ global scaling
        # 3. Clamp for numerical stability
        #
        # Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents)
        # Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound

        # Step 1: Normalize by the maximum eigenvalue to get relative scales
        # (topk returns sorted values, so index 0 is the maximum)
        max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0

        if max_eigenval > 1e-10:
            # Scale so max eigenvalue = 1.0, preserving relative ratios
            top_eigenvalues = top_eigenvalues / max_eigenval

        # Step 2: Apply gamma and clamp to safe range
        top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0)

        # Convert to fp16 for storage - now safe since eigenvalues are ~0.01-1.0
        # fp16 range: 6e-5 to 65,504, our values are well within this
        eigenvectors_fp16 = top_eigenvectors.cpu().half()
        eigenvalues_fp16 = top_eigenvalues.cpu().half()

        # Cleanup intermediate GPU tensors to keep memory pressure low
        del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return eigenvectors_fp16, eigenvalues_fp16

    def compute_for_batch(
        self,
        latents_np: np.ndarray,
        global_indices: List[int]
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute Γ_b for all points in a batch of same-size latents.

        Args:
            latents_np: (N, d) numpy array
            global_indices: List of global dataset indices for each latent

        Returns:
            Dict mapping global_idx -> (eigenvectors, eigenvalues), both fp16
            torch tensors of shape (d_cdc, d) and (d_cdc,) respectively.

        Raises:
            ValueError: If len(global_indices) != N.
        """
        N, d = latents_np.shape

        # Validate inputs
        if len(global_indices) != N:
            raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")

        print(f"Computing CDC for batch: {N} samples, dim={d}")

        # Handle small sample cases - require minimum samples for meaningful k-NN
        MIN_SAMPLES_FOR_CDC = 5  # Need at least 5 samples for reasonable geometry estimation

        if N < MIN_SAMPLES_FOR_CDC:
            print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
            results = {}
            for global_idx in global_indices:
                # Zero eigenvectors/eigenvalues signal "no CDC correction"
                # (results in identity behavior in compute_sigma_t_x).
                # Fixed: use torch tensors to match the annotated return type
                # and the other return paths (compute_gamma_b_single).
                eigvecs = torch.zeros(self.d_cdc, d, dtype=torch.float16)
                eigvals = torch.zeros(self.d_cdc, dtype=torch.float16)
                results[global_idx] = (eigvecs, eigvals)
            return results

        # Step 1: Build k-NN graph
        print(" Building k-NN graph...")
        distances, indices = self.compute_knn_graph(latents_np)

        # Step 2: Compute bandwidth as distance to the k_bw-th neighbor
        # Use min to handle case where k_bw >= actual neighbors returned
        k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
        epsilon = distances[:, k_bw_actual]

        # Step 3: Compute Γ_b for each point
        results = {}
        print(" Computing Γ_b for each point...")
        for local_idx in tqdm(range(N), desc=" Processing", leave=False):
            global_idx = global_indices[local_idx]
            eigvecs, eigvals = self.compute_gamma_b_single(
                local_idx, latents_np, distances, indices, epsilon
            )
            results[global_idx] = (eigvecs, eigvals)

        return results
|
||
|
||
|
||
class LatentBatcher:
    """
    Collects variable-size latents and batches them by size.

    get_batches() groups by exact latent shape only; the aspect-ratio and
    size-tolerance helpers are retained for callers that need approximate
    grouping but are not used by the exact-shape path.
    """

    def __init__(self, size_tolerance: float = 0.0):
        """
        Args:
            size_tolerance: If > 0, _shapes_similar() treats latents within
                this fractional size difference as matching.
                If 0, only exact size matches are considered similar.
        """
        self.size_tolerance = size_tolerance
        self.samples: List["LatentSample"] = []

    def add_sample(self, sample: "LatentSample"):
        """Add a single latent sample"""
        self.samples.append(sample)

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a latent vector with automatic shape tracking.

        Args:
            latent: Latent vector (any shape, will be flattened)
            global_idx: Global index in dataset
            shape: Original shape (if None, uses latent.shape)
            metadata: Optional metadata dict
        """
        # Convert to numpy and flatten. detach() makes this safe for tensors
        # that require grad (plain .cpu().numpy() would raise for those).
        if isinstance(latent, torch.Tensor):
            latent_np = latent.detach().cpu().numpy()
        else:
            latent_np = latent

        original_shape = shape if shape is not None else latent_np.shape
        latent_flat = latent_np.flatten()

        self.add_sample(LatentSample(
            latent=latent_flat,
            global_idx=global_idx,
            shape=original_shape,
            metadata=metadata
        ))

    def get_batches(self) -> Dict[Tuple[int, ...], List["LatentSample"]]:
        """
        Group samples by exact shape to avoid resizing distortion.

        Each bucket contains only samples with identical latent dimensions.
        Buckets with fewer than k_neighbors samples will be skipped during CDC
        computation and fall back to standard Gaussian noise.

        Returns:
            Dict mapping exact_shape -> list of samples with that shape
        """
        batches: Dict[Tuple[int, ...], List["LatentSample"]] = {}
        for sample in self.samples:
            # Group by exact shape only - no aspect ratio grouping or resizing
            batches.setdefault(sample.shape, []).append(sample)
        return batches

    def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
        """
        Get aspect ratio category for grouping.
        Groups images by aspect ratio bins to ensure sufficient samples.

        For shape (C, H, W), computes aspect ratio H/W and bins it.
        """
        if len(shape) < 3:
            return "unknown"

        # Extract spatial dimensions (H, W)
        h, w = shape[-2], shape[-1]
        aspect_ratio = h / w

        # Define aspect ratio bins (±15% tolerance)
        # Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16)
        bins = [
            (0.5, 0.65, "9:16"),   # Portrait tall
            (0.65, 0.85, "3:4"),   # Portrait
            (0.85, 1.15, "1:1"),   # Square
            (1.15, 1.50, "4:3"),   # Landscape
            (1.50, 2.0, "16:9"),   # Landscape wide
            (2.0, 3.0, "21:9"),    # Ultra wide
        ]

        for min_ratio, max_ratio, label in bins:
            if min_ratio <= aspect_ratio < max_ratio:
                return label

        # Fallback for extreme ratios outside all bins
        if aspect_ratio < 0.5:
            return "ultra_tall"
        else:
            return "ultra_wide"

    def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
        """Check if two shapes are within tolerance (by total element count)."""
        if len(shape1) != len(shape2):
            return False

        size1 = np.prod(shape1)
        size2 = np.prod(shape2)

        ratio = abs(size1 - size2) / max(size1, size2)
        # bool() so callers get a Python bool, not np.bool_ (matches annotation)
        return bool(ratio <= self.size_tolerance)

    def __len__(self):
        return len(self.samples)
|
||
|
||
|
||
class CDCPreprocessor:
    """
    High-level CDC preprocessing coordinator.

    Handles variable-size latents by batching them into exact-shape buckets
    and delegating per-bucket math to CarreDuChampComputer. Buckets too small
    for a meaningful k-NN estimate get zero eigenpairs, which downstream
    consumers interpret as standard Gaussian noise (no CDC correction).
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda',
        size_tolerance: float = 0.0,
        debug: bool = False,
        adaptive_k: bool = False,
        min_bucket_size: int = 16
    ):
        """
        Args:
            k_neighbors: Maximum number of neighbors for CDC computation
            k_bandwidth: Neighbor rank used for the per-point bandwidth
            d_cdc: Number of eigenpairs kept per sample
            gamma: Global eigenvalue scaling hyperparameter
            device: Compute device for the underlying computer
            size_tolerance: Forwarded to LatentBatcher (0 = exact-shape buckets)
            debug: Verbose per-bucket printing instead of logger/progress bar
            adaptive_k: If True, buckets with >= min_bucket_size samples use
                k = min(k_neighbors, bucket_size - 1); if False, buckets with
                fewer than k_neighbors samples are skipped entirely
            min_bucket_size: Minimum bucket size for CDC in adaptive mode
        """
        self.computer = CarreDuChampComputer(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=d_cdc,
            gamma=gamma,
            device=device
        )
        self.batcher = LatentBatcher(size_tolerance=size_tolerance)
        self.debug = debug
        self.adaptive_k = adaptive_k
        self.min_bucket_size = min_bucket_size

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a single latent to the preprocessing queue.

        Args:
            latent: Latent vector (will be flattened)
            global_idx: Global dataset index
            shape: Original shape (C, H, W)
            metadata: Optional metadata
        """
        self.batcher.add_latent(latent, global_idx, shape, metadata)

    def _store_gaussian_fallback(self, samples: list, shape: Tuple[int, ...], all_results: dict):
        """Record zero eigenpairs for every sample in a skipped bucket.

        Zero eigenvectors/eigenvalues signal "no CDC correction" downstream
        (compute_sigma_t_x falls back to standard Gaussian noise).
        """
        C, H, W = shape
        d = C * H * W
        for sample in samples:
            eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
            eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
            all_results[sample.global_idx] = (eigvecs, eigvals)

    def compute_all(self, save_path: Union[str, Path]) -> Path:
        """
        Compute Γ_b for all added latents and save to safetensors.

        Args:
            save_path: Path to save the results

        Returns:
            Path to saved file
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Get batches by exact size (no resizing)
        batches = self.batcher.get_batches()

        k_neighbors = self.computer.k
        # In adaptive mode a bucket only needs min_bucket_size samples;
        # in fixed mode it needs the full k_neighbors.
        min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors

        # Count samples that will get CDC vs fallback. min_threshold already
        # encodes the mode, so one expression covers both (the previous
        # adaptive/fixed branches computed the identical sum).
        samples_with_cdc = sum(
            len(samples) for samples in batches.values() if len(samples) >= min_threshold
        )
        samples_fallback = len(self.batcher) - samples_with_cdc

        total = len(self.batcher)
        if self.debug:
            print(f"\nProcessing {total} samples in {len(batches)} exact size buckets")
            if self.adaptive_k:
                print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
            denom = max(total, 1)  # guard against division by zero on empty dataset
            print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{total} ({samples_with_cdc/denom*100:.1f}%)")
            print(f" Samples using Gaussian fallback: {samples_fallback}/{total} ({samples_fallback/denom*100:.1f}%)")
        else:
            mode = "adaptive" if self.adaptive_k else "fixed"
            logger.info(f"Processing {total} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")

        # Storage for results
        all_results = {}

        # Progress bar for normal runs; disabled when debug printing is on
        bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug)

        for shape, samples in bucket_iter:
            num_samples = len(samples)

            if self.debug:
                print(f"\n{'='*60}")
                print(f"Bucket: {shape} ({num_samples} samples)")
                print(f"{'='*60}")

            # Skip buckets that are too small; both modes share the same
            # Gaussian-fallback path (only the threshold and message differ).
            if num_samples < min_threshold:
                if self.debug:
                    if self.adaptive_k:
                        print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
                    else:
                        print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
                    print(" → These samples will use standard Gaussian noise (no CDC)")
                self._store_gaussian_fallback(samples, shape, all_results)
                continue

            # Effective k: adaptive mode shrinks k for small-but-usable buckets
            k_effective = min(k_neighbors, num_samples - 1) if self.adaptive_k else k_neighbors

            # Collect latents (no resizing needed - all same shape, already flattened)
            global_indices = [sample.global_idx for sample in samples]
            latents_list = [sample.latent for sample in samples]
            latents_np = np.stack(latents_list, axis=0)  # (N, C*H*W)

            if self.debug:
                if self.adaptive_k and k_effective < k_neighbors:
                    print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
                else:
                    print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")

            # Temporarily override k for this bucket; restore even if the
            # computation raises, so one bad bucket can't corrupt later ones.
            original_k = self.computer.k
            self.computer.k = k_effective
            try:
                batch_results = self.computer.compute_for_batch(latents_np, global_indices)
            finally:
                self.computer.k = original_k

            if self.debug:
                print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")

            # Merge into overall results
            all_results.update(batch_results)

        # Save to safetensors
        if self.debug:
            print(f"\n{'='*60}")
            print("Saving results...")
            print(f"{'='*60}")

        tensors_dict = {
            'metadata/num_samples': torch.tensor([len(all_results)]),
            'metadata/k_neighbors': torch.tensor([self.computer.k]),
            'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
            'metadata/gamma': torch.tensor([self.computer.gamma]),
            # Record the bucketing policy so consumers can tell how the file
            # was produced (fixed vs adaptive k) - previously untracked.
            'metadata/adaptive_k': torch.tensor([int(self.adaptive_k)]),
            'metadata/min_bucket_size': torch.tensor([self.min_bucket_size]),
        }

        # Add shape information and CDC results for each sample,
        # keyed by image_key.
        # NOTE(review): assumes every sample was added with a metadata dict
        # containing an 'image_key' entry - raises otherwise.
        for sample in self.batcher.samples:
            image_key = sample.metadata['image_key']
            tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape)

            # Get CDC results for this sample
            if sample.global_idx in all_results:
                eigvecs, eigvals = all_results[sample.global_idx]

                # Results may be numpy (fallback path) or torch (normal path)
                if isinstance(eigvecs, np.ndarray):
                    eigvecs = torch.from_numpy(eigvecs)
                if isinstance(eigvals, np.ndarray):
                    eigvals = torch.from_numpy(eigvals)

                tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
                tensors_dict[f'eigenvalues/{image_key}'] = eigvals

        save_file(tensors_dict, save_path)

        file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
        logger.info(f"Saved to {save_path}")
        logger.info(f"File size: {file_size_gb:.2f} GB")

        return save_path
|
||
|
||
|
||
class GammaBDataset:
    """
    Efficient loader for Γ_b matrices during training.

    Reads the safetensors file produced by CDCPreprocessor.compute_all():
    per-sample eigenvectors/eigenvalues keyed by image_key, plus cached
    original shapes. Handles variable-size latents.
    """

    def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
        """
        Args:
            gamma_b_path: Path to the safetensors file written by the preprocessor
            device: Preferred device; silently falls back to CPU if CUDA is absent
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.gamma_b_path = Path(gamma_b_path)

        # Load metadata (import kept local so safetensors is only required here)
        logger.info(f"Loading Γ_b from {gamma_b_path}...")
        from safetensors import safe_open

        with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
            # Scalar metadata written by the preprocessor
            self.num_samples = int(f.get_tensor('metadata/num_samples').item())
            self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())

            # Cache all shapes in memory to avoid repeated I/O during training.
            # Loading once at init is much faster than opening the file every
            # training step.
            self.shapes_cache = {}
            # Get all shape keys (they're stored as shapes/{image_key})
            all_keys = f.keys()
            shape_keys = [k for k in all_keys if k.startswith('shapes/')]
            for shape_key in shape_keys:
                image_key = shape_key.replace('shapes/', '')
                shape_tensor = f.get_tensor(shape_key)
                self.shapes_cache[image_key] = tuple(shape_tensor.numpy().tolist())

        logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
        logger.info(f"Cached {len(self.shapes_cache)} shapes in memory")

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        image_keys: Union[List[str], List],
        device: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get Γ_b^(1/2) components for a batch of image_keys.

        Args:
            image_keys: List of image_key strings
            device: Device to load to (defaults to self.device)

        Returns:
            eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
            eigenvalues: (B, d_cdc)

        Raises:
            RuntimeError: If the batch mixes latent dimensions (violates the
                uniform-dimension-per-batch requirement enforced by bucketing).
        """
        if device is None:
            device = self.device

        # Load from safetensors on every call - eigenvectors are too large
        # to cache wholesale, unlike the shapes cached in __init__
        from safetensors import safe_open

        eigenvectors_list = []
        eigenvalues_list = []

        with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
            for image_key in image_keys:
                # Stored as fp16 by the preprocessor; upcast for stable math
                eigvecs = f.get_tensor(f'eigenvectors/{image_key}').float()
                eigvals = f.get_tensor(f'eigenvalues/{image_key}').float()

                eigenvectors_list.append(eigvecs)
                eigenvalues_list.append(eigvals)

        # Stack - all should have same d_cdc and d within a batch (enforced by bucketing)
        # Check if all eigenvectors have the same dimension before stacking
        dims = [ev.shape[1] for ev in eigenvectors_list]
        if len(set(dims)) > 1:
            # Dimension mismatch! This shouldn't happen with proper bucketing
            # but can occur if batch contains mixed sizes
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
                f"Image keys: {image_keys}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )

        eigenvectors = torch.stack(eigenvectors_list, dim=0)
        eigenvalues = torch.stack(eigenvalues_list, dim=0)

        return eigenvectors, eigenvalues

    def get_shape(self, image_key: str) -> Tuple[int, ...]:
        """Get the original shape for a sample (cached in memory).

        Raises KeyError if the image_key was not in the preprocessed file.
        """
        return self.shapes_cache[image_key]

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2)

        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d) or (B, C, H, W) - will be flattened if needed
            t: (B,) or scalar time

        Returns:
            result: Same shape as input x

        Note:
            Gradients flow through this function for backprop during training.
        """
        # Store original shape to restore later
        orig_shape = x.shape

        # Flatten x if it's 4D so the einsums below see (B, d)
        if x.dim() == 4:
            B, C, H, W = x.shape
            x = x.reshape(B, -1)  # (B, C*H*W)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        # Broadcast scalar t across the batch
        if t.dim() == 0:
            t = t.expand(x.shape[0])

        t = t.view(-1, 1)

        # Early return for t=0 to avoid numerical errors
        # (only fires when the ENTIRE batch has t≈0)
        if torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)

        # Check if CDC is disabled (all eigenvalues are zero)
        # This happens for buckets with < k_neighbors samples.
        # NOTE(review): this is a whole-batch check - it assumes a training
        # batch never mixes CDC and fallback samples (same-bucket batching
        # should guarantee this); confirm against the dataloader.
        if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8):
            # Fallback to standard Gaussian noise (no CDC correction)
            return x.reshape(orig_shape)

        # Γ_b^(1/2) @ x using the low-rank representation:
        # V^T x -> scale by sqrt(λ) -> project back with V
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        # Clamp guards sqrt against tiny negatives/zeros from fp16 round-trip
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)

        # Σ_t @ x: linear interpolation between identity and Γ_b^(1/2)
        result = (1 - t) * x + t * gamma_sqrt_x

        # Restore original shape
        result = result.reshape(orig_shape)

        return result
|