Files
Kohya-ss-sd-scripts/library/cdc_fm.py
rockerBOO 8089cb6925 Improve dimension mismatch warning for CDC Flow Matching
- Add explicit warning and tracking for multiple unique latent shapes
- Simplify test imports by removing unused modules
- Minor formatting improvements in print statements
- Ensure log messages provide clear context about dimension mismatches
2025-10-11 17:17:09 -04:00

797 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from safetensors.torch import save_file
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
try:
import faiss # type: ignore
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger = logging.getLogger(__name__)
@dataclass
class LatentSample:
    """
    One flattened latent plus the bookkeeping needed to batch it
    """

    # Flattened latent vector, shape (d,)
    latent: np.ndarray
    # Position of this sample in the full dataset
    global_idx: int
    # Shape the latent had before flattening, e.g. (C, H, W)
    shape: Tuple[int, ...]
    # Free-form extras (prompt, filename, etc.); absent unless supplied
    metadata: Optional[Dict] = None
class CarreDuChampComputer:
    """
    Core CDC-FM computation - agnostic to data source

    Handles the math for one batch of same-size latents: a FAISS k-NN graph,
    a per-point bandwidth, and the top ``d_cdc`` eigenpairs of a
    kernel-weighted local covariance (Γ_b) around every point.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda'
    ):
        """
        Args:
            k_neighbors: Number of nearest neighbors used per point
            k_bandwidth: Neighbor rank whose distance is used as the bandwidth
            d_cdc: Number of eigenpairs kept per point
            gamma: Global scale applied to the normalized eigenvalues
            device: Requested torch device; CPU is used when CUDA is unavailable
        """
        self.k = k_neighbors
        self.k_bw = k_bandwidth
        self.d_cdc = d_cdc
        self.gamma = gamma
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build k-NN graph using FAISS

        Args:
            latents_np: (N, d) numpy array of same-dimensional latents
        Returns:
            distances: (N, k_actual+1) distances (k_actual may be less than k if N is small)
            indices: (N, k_actual+1) neighbor indices
        """
        N, d = latents_np.shape
        # Clamp k to available neighbors (can't have more neighbors than samples)
        k_actual = max(min(self.k, N - 1), 0)
        # FAISS needs C-contiguous float32 input; this is a no-op when already so
        latents_np = np.ascontiguousarray(latents_np, dtype=np.float32)
        # Build FAISS index
        index = faiss.IndexFlatL2(d)
        # Only go to GPU when this computer was configured for CUDA AND the
        # installed faiss build has GPU support (faiss-cpu lacks these symbols).
        use_gpu = self.device.type == 'cuda' and hasattr(faiss, 'StandardGpuResources')
        if use_gpu:
            gpu_id = self.device.index if self.device.index is not None else 0
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, gpu_id, index)
        index.add(latents_np)  # type: ignore
        distances, indices = index.search(latents_np, k_actual + 1)  # type: ignore
        return distances, indices

    @torch.no_grad()
    def compute_gamma_b_single(
        self,
        point_idx: int,
        latents_np: np.ndarray,
        distances: np.ndarray,
        indices: np.ndarray,
        epsilon: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Γ_b for a single point

        Args:
            point_idx: Index of point to process
            latents_np: (N, d) all latents in this batch
            distances: (N, k+1) precomputed distances
            indices: (N, k+1) precomputed neighbor indices
            epsilon: (N,) bandwidth per point
        Returns:
            eigenvectors: (d_cdc, d) as half precision tensor
            eigenvalues: (d_cdc,) as half precision tensor
            Both are all-zero when the GPU and CPU SVD both fail.
        """
        d = latents_np.shape[1]
        # Get neighbors (exclude self): column 0 is treated as the query point
        neighbor_idx = indices[point_idx, 1:]  # (k,)
        neighbor_points = latents_np[neighbor_idx]  # (k, d)

        # Clamp distances to prevent overflow (max realistic L2 distance)
        MAX_DISTANCE = 1e10
        neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
        # NOTE(review): faiss IndexFlatL2 is documented to report *squared* L2
        # distances; if so this squares them a second time (and epsilon is
        # squared too) - confirm this is the intended kernel.
        neighbor_dists_sq = neighbor_dists ** 2  # (k,)

        # Compute Gaussian kernel weights with numerical guards
        eps_i = max(epsilon[point_idx], 1e-10)  # Prevent division by zero
        eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)
        # Denominator, with an additional guard against underflow
        denom = np.maximum(eps_i * eps_neighbors, 1e-20)
        # Clip the exponent so exp() can neither overflow nor underflow
        exp_arg = np.clip(-neighbor_dists_sq / denom, -50, 0)
        weights = np.exp(exp_arg)

        # Normalize weights, handle edge case of all zeros
        weight_sum = weights.sum()
        if weight_sum < 1e-20 or not np.isfinite(weight_sum):
            # Fallback to uniform weights
            weights = np.ones_like(weights) / len(weights)
        else:
            weights = weights / weight_sum

        # Compute local mean
        m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
        # Center and weight for SVD
        centered = neighbor_points - m_star
        weighted_centered = np.sqrt(weights)[:, None] * centered  # (k, d)

        # Validate input is finite before SVD
        if not np.all(np.isfinite(weighted_centered)):
            logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
            # Fallback: use uniform weights and simple centering
            weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
            m_star = np.mean(neighbor_points, axis=0)
            centered = neighbor_points - m_star
            weighted_centered = np.sqrt(weights_uniform)[:, None] * centered

        # Move to the compute device for SVD
        weighted_centered_torch = torch.from_numpy(weighted_centered).to(
            self.device, dtype=torch.float32
        )
        try:
            _, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
        except RuntimeError as e:
            logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
            try:
                _, S_np, Vh_np = np.linalg.svd(weighted_centered, full_matrices=False)
                S = torch.from_numpy(S_np).to(self.device)
                Vh = torch.from_numpy(Vh_np).to(self.device)
            except np.linalg.LinAlgError as e2:
                logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
                # Return zero eigenvalues/vectors as fallback
                return (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16)
                )

        # Eigenvalues of Γ_b
        eigenvalues_full = S ** 2
        n_eig = len(eigenvalues_full)
        if n_eig >= self.d_cdc:
            # Keep top d_cdc
            top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
            top_eigenvectors = Vh[top_idx, :]  # (d_cdc, d)
        else:
            # Pad with zeros if k < d_cdc
            pad_size = self.d_cdc - n_eig
            top_eigenvalues = torch.cat([
                eigenvalues_full,
                torch.zeros(pad_size, device=self.device)
            ])
            top_eigenvectors = torch.cat([
                Vh,
                torch.zeros(pad_size, d, device=self.device)
            ])

        # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
        # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
        # Then apply gamma: γc_i Γ̂(x^(i))
        #
        # Our implementation:
        # 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor
        # 2. Apply gamma hyperparameter - aligns with paper's γ global scaling
        # 3. Clamp for numerical stability
        #
        # Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents)
        # Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound

        # Step 1: Normalize by the maximum eigenvalue to get relative scales
        max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0
        if max_eigenval > 1e-10:
            # Scale so max eigenvalue = 1.0, preserving relative ratios
            top_eigenvalues = top_eigenvalues / max_eigenval

        # Step 2: Apply gamma and clamp to safe range
        # (zero-padded entries are lifted to the 1e-3 floor by this clamp)
        top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0)

        # Convert to fp16 for storage - safe now that eigenvalues are rescaled
        eigenvectors_fp16 = top_eigenvectors.cpu().half()
        eigenvalues_fp16 = top_eigenvalues.cpu().half()

        # Cleanup
        del weighted_centered_torch, S, Vh, top_eigenvectors, top_eigenvalues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return eigenvectors_fp16, eigenvalues_fp16

    def compute_for_batch(
        self,
        latents_np: np.ndarray,
        global_indices: List[int]
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute Γ_b for all points in a batch of same-size latents

        Args:
            latents_np: (N, d) numpy array
            global_indices: List of global dataset indices for each latent
        Returns:
            Dict mapping global_idx -> (eigenvectors, eigenvalues).
            For batches below the minimum sample count the values are
            all-zero float16 NumPy arrays rather than tensors.
        Raises:
            ValueError: If ``global_indices`` does not have one entry per latent
        """
        N, d = latents_np.shape
        # Validate inputs
        if len(global_indices) != N:
            raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")
        print(f"Computing CDC for batch: {N} samples, dim={d}")

        # Handle small sample cases - require minimum samples for meaningful k-NN
        MIN_SAMPLES_FOR_CDC = 5  # Need at least 5 samples for reasonable geometry estimation
        if N < MIN_SAMPLES_FOR_CDC:
            print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
            # Zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x)
            return {
                global_idx: (
                    np.zeros((self.d_cdc, d), dtype=np.float16),
                    np.zeros(self.d_cdc, dtype=np.float16),
                )
                for global_idx in global_indices
            }

        # Step 1: Build k-NN graph
        print(" Building k-NN graph...")
        distances, indices = self.compute_knn_graph(latents_np)

        # Step 2: Compute bandwidth
        # Use min to handle case where k_bw >= actual neighbors returned
        k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
        epsilon = distances[:, k_bw_actual]

        # Step 3: Compute Γ_b for each point
        results = {}
        print(" Computing Γ_b for each point...")
        for local_idx in tqdm(range(N), desc=" Processing", leave=False):
            results[global_indices[local_idx]] = self.compute_gamma_b_single(
                local_idx, latents_np, distances, indices, epsilon
            )
        return results
class LatentBatcher:
    """
    Collects variable-size latents and batches them by size
    """

    def __init__(self, size_tolerance: float = 0.0):
        """
        Args:
            size_tolerance: If > 0, group latents within tolerance % of size
                If 0, only exact size matches are batched
        """
        self.size_tolerance = size_tolerance
        self.samples: List["LatentSample"] = []

    def add_sample(self, sample: "LatentSample"):
        """Add a single latent sample"""
        self.samples.append(sample)

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a latent vector with automatic shape tracking

        Args:
            latent: Latent vector (any shape, will be flattened)
            global_idx: Global index in dataset
            shape: Original shape (if None, uses latent.shape)
            metadata: Optional metadata dict
        """
        # Convert to numpy and flatten
        if isinstance(latent, torch.Tensor):
            # detach(): .numpy() refuses tensors that require grad
            host_tensor = latent.detach().cpu()
            if host_tensor.dtype == torch.bfloat16:
                # NumPy has no bfloat16 dtype, so upcast before converting
                host_tensor = host_tensor.float()
            latent_np = host_tensor.numpy()
        else:
            latent_np = np.asarray(latent)
        # Stored as a plain tuple so it is hashable as a bucket key in get_batches
        original_shape = tuple(shape) if shape is not None else latent_np.shape
        latent_flat = latent_np.flatten()
        sample = LatentSample(
            latent=latent_flat,
            global_idx=global_idx,
            shape=original_shape,
            metadata=metadata
        )
        self.add_sample(sample)

    def get_batches(self) -> Dict[Tuple[int, ...], List["LatentSample"]]:
        """
        Group samples by exact shape to avoid resizing distortion.

        Each bucket contains only samples with identical latent dimensions.
        A warning is logged when more than one distinct shape is present.

        Returns:
            Dict mapping exact_shape -> list of samples with that shape
        """
        batches: Dict[Tuple[int, ...], List["LatentSample"]] = {}
        for sample in self.samples:
            # Group by exact shape only - no aspect ratio grouping or resizing
            batches.setdefault(sample.shape, []).append(sample)

        # If more than one unique shape, log a warning
        shapes = set(batches)
        if len(shapes) > 1:
            logger.warning(
                "Dimension mismatch: %d unique shapes detected. "
                "Shapes: %s. Using Gaussian fallback for these samples.",
                len(shapes),
                shapes
            )
        return batches

    def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
        """
        Get aspect ratio category for grouping.

        For shape (..., H, W), computes the ratio H/W and maps it to a bin
        label. Shapes with fewer than 3 dims, or with W == 0, yield "unknown".
        """
        if len(shape) < 3:
            return "unknown"
        # Extract spatial dimensions (H, W)
        h, w = shape[-2], shape[-1]
        if w == 0:
            # Degenerate width: the ratio is undefined, avoid ZeroDivisionError
            return "unknown"
        aspect_ratio = h / w

        # Half-open [min, max) bins over H/W.
        # NOTE(review): with ratio = H/W, values below 1 are wider than tall;
        # the "ultra_tall"/"ultra_wide" fallbacks look swapped - confirm.
        bins = [
            (0.5, 0.65, "9:16"),
            (0.65, 0.85, "3:4"),
            (0.85, 1.15, "1:1"),
            (1.15, 1.50, "4:3"),
            (1.50, 2.0, "16:9"),
            (2.0, 3.0, "21:9"),
        ]
        for min_ratio, max_ratio, label in bins:
            if min_ratio <= aspect_ratio < max_ratio:
                return label
        # Fallback for extreme ratios
        return "ultra_tall" if aspect_ratio < 0.5 else "ultra_wide"

    def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
        """Check if two shapes of equal rank are within size_tolerance of each other"""
        if len(shape1) != len(shape2):
            return False
        size1 = np.prod(shape1)
        size2 = np.prod(shape2)
        largest = max(size1, size2)
        # Two zero-size shapes differ by nothing; avoid 0/0 -> nan
        ratio = 0.0 if largest == 0 else abs(size1 - size2) / largest
        return bool(ratio <= self.size_tolerance)

    def __len__(self):
        return len(self.samples)
class CDCPreprocessor:
    """
    High-level CDC preprocessing coordinator

    Handles variable-size latents by batching and delegating to CarreDuChampComputer
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda',
        size_tolerance: float = 0.0,
        debug: bool = False,
        adaptive_k: bool = False,
        min_bucket_size: int = 16
    ):
        """
        Args:
            k_neighbors: Neighbors per point (upper bound when adaptive_k is on)
            k_bandwidth: Neighbor rank used for the bandwidth estimate
            d_cdc: Number of eigenpairs stored per sample
            gamma: Global eigenvalue scale
            device: Requested torch device for the computer
            size_tolerance: Forwarded to LatentBatcher
            debug: Print verbose per-bucket progress instead of a progress bar
            adaptive_k: Shrink k for buckets smaller than k_neighbors instead of skipping them
            min_bucket_size: Smallest bucket that still gets CDC in adaptive mode
        Raises:
            ImportError: If FAISS is not installed
        """
        if not FAISS_AVAILABLE:
            raise ImportError(
                "FAISS is required for CDC-FM but not installed. "
                "Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU). "
                "CDC-FM will be disabled."
            )
        self.computer = CarreDuChampComputer(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=d_cdc,
            gamma=gamma,
            device=device
        )
        self.batcher = LatentBatcher(size_tolerance=size_tolerance)
        self.debug = debug
        self.adaptive_k = adaptive_k
        self.min_bucket_size = min_bucket_size

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a single latent to the preprocessing queue

        Args:
            latent: Latent vector (will be flattened)
            global_idx: Global dataset index
            shape: Original shape (C, H, W)
            metadata: Optional metadata; compute_all reads metadata['image_key']
        """
        self.batcher.add_latent(latent, global_idx, shape, metadata)

    def _store_gaussian_fallback(self, shape, samples, all_results) -> None:
        """Record all-zero eigenpairs (the Gaussian-fallback marker) for every sample of a bucket."""
        # Flattened latent size; works for any rank, not only (C, H, W)
        d = int(np.prod(shape))
        for sample in samples:
            all_results[sample.global_idx] = (
                np.zeros((self.computer.d_cdc, d), dtype=np.float16),
                np.zeros(self.computer.d_cdc, dtype=np.float16),
            )

    def compute_all(self, save_path: Union[str, Path]) -> Path:
        """
        Compute Γ_b for all added latents and save to safetensors

        Buckets smaller than the threshold (min_bucket_size in adaptive mode,
        k_neighbors otherwise) are stored with all-zero eigenpairs.

        Args:
            save_path: Path to save the results
        Returns:
            Path to saved file
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Get batches by exact size (no resizing)
        batches = self.batcher.get_batches()

        # Count samples that will get CDC vs fallback
        k_neighbors = self.computer.k
        min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
        total_samples = len(self.batcher)
        samples_with_cdc = sum(
            len(samples) for samples in batches.values() if len(samples) >= min_threshold
        )
        samples_fallback = total_samples - samples_with_cdc

        if self.debug:
            # Guard the percentages against an empty queue
            pct_cdc = samples_with_cdc / total_samples * 100 if total_samples else 0.0
            pct_fallback = samples_fallback / total_samples * 100 if total_samples else 0.0
            print(f"\nProcessing {total_samples} samples in {len(batches)} exact size buckets")
            if self.adaptive_k:
                print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
            print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{total_samples} ({pct_cdc:.1f}%)")
            print(f" Samples using Gaussian fallback: {samples_fallback}/{total_samples} ({pct_fallback:.1f}%)")
        else:
            mode = "adaptive" if self.adaptive_k else "fixed"
            logger.info(f"Processing {total_samples} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")

        # Storage for results
        all_results = {}

        # Progress bar only outside debug mode (debug prints its own progress)
        if self.debug:
            bucket_iter = batches.items()
        else:
            bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket")

        for shape, samples in bucket_iter:
            num_samples = len(samples)
            if self.debug:
                print(f"\n{'='*60}")
                print(f"Bucket: {shape} ({num_samples} samples)")
                print(f"{'='*60}")

            # Too-small buckets are skipped and marked as Gaussian fallback
            if num_samples < min_threshold:
                if self.debug:
                    if self.adaptive_k:
                        print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
                    else:
                        print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
                    print(" → These samples will use standard Gaussian noise (no CDC)")
                self._store_gaussian_fallback(shape, samples, all_results)
                continue

            # Determine effective k for this bucket
            if self.adaptive_k:
                k_effective = min(k_neighbors, num_samples - 1)
            else:
                k_effective = k_neighbors

            # Collect latents (no resizing needed - all same shape, already flattened)
            global_indices = [sample.global_idx for sample in samples]
            latents_np = np.stack([sample.latent for sample in samples], axis=0)  # (N, C*H*W)

            if self.debug:
                if self.adaptive_k and k_effective < k_neighbors:
                    print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
                else:
                    print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")

            # Temporarily override k for this bucket; always restore it, even
            # if the computation raises, so the computer is left consistent.
            original_k = self.computer.k
            self.computer.k = k_effective
            try:
                batch_results = self.computer.compute_for_batch(latents_np, global_indices)
            finally:
                self.computer.k = original_k

            # No resizing needed - eigenvectors are already correct size
            if self.debug:
                print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")

            # Merge into overall results
            all_results.update(batch_results)

        # Save to safetensors
        if self.debug:
            print(f"\n{'='*60}")
            print("Saving results...")
            print(f"{'='*60}")

        tensors_dict = {
            'metadata/num_samples': torch.tensor([len(all_results)]),
            'metadata/k_neighbors': torch.tensor([self.computer.k]),
            'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
            'metadata/gamma': torch.tensor([self.computer.gamma]),
        }

        # Add shape information and CDC results for each sample
        # Use image_key as the identifier (every sample needs metadata['image_key'])
        for sample in self.batcher.samples:
            image_key = sample.metadata['image_key']
            tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape)
            # Get CDC results for this sample
            if sample.global_idx in all_results:
                eigvecs, eigvals = all_results[sample.global_idx]
                # Fallback entries are numpy arrays; convert them to torch tensors
                if isinstance(eigvecs, np.ndarray):
                    eigvecs = torch.from_numpy(eigvecs)
                if isinstance(eigvals, np.ndarray):
                    eigvals = torch.from_numpy(eigvals)
                tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
                tensors_dict[f'eigenvalues/{image_key}'] = eigvals

        save_file(tensors_dict, save_path)
        file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
        logger.info(f"Saved to {save_path}")
        logger.info(f"File size: {file_size_gb:.2f} GB")
        return save_path
class GammaBDataset:
    """
    Efficient loader for Γ_b matrices during training

    Handles variable-size latents
    """

    def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
        """
        Args:
            gamma_b_path: Path to the safetensors file holding the CDC data
            device: Requested torch device; CPU is used when CUDA is unavailable
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.gamma_b_path = Path(gamma_b_path)

        # Load metadata
        logger.info(f"Loading Γ_b from {gamma_b_path}...")
        from safetensors import safe_open
        with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
            self.num_samples = int(f.get_tensor('metadata/num_samples').item())
            self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())

            # Cache all shapes in memory to avoid repeated I/O during training
            # Loading once at init is much faster than opening the file every training step
            self.shapes_cache = {}
            # Shapes are stored as shapes/{image_key}
            prefix = 'shapes/'
            for shape_key in f.keys():
                if not shape_key.startswith(prefix):
                    continue
                # Strip only the leading prefix: image keys may themselves
                # contain "shapes/" (e.g. a path), which replace() would mangle.
                image_key = shape_key[len(prefix):]
                shape_tensor = f.get_tensor(shape_key)
                self.shapes_cache[image_key] = tuple(shape_tensor.numpy().tolist())

        logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
        logger.info(f"Cached {len(self.shapes_cache)} shapes in memory")

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        image_keys: Union[List[str], List],
        device: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get Γ_b^(1/2) components for a batch of image_keys

        Args:
            image_keys: List of image_key strings
            device: Device to load to (defaults to self.device)
        Returns:
            eigenvectors: (B, d_cdc, d), float32
            eigenvalues: (B, d_cdc), float32
        Raises:
            RuntimeError: If the requested samples do not all share the same d
        """
        if device is None:
            device = self.device

        # Load from safetensors
        from safetensors import safe_open
        eigenvectors_list = []
        eigenvalues_list = []
        with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
            for image_key in image_keys:
                eigenvectors_list.append(f.get_tensor(f'eigenvectors/{image_key}').float())
                eigenvalues_list.append(f.get_tensor(f'eigenvalues/{image_key}').float())

        # Stacking requires every sample to share the same d
        dims = [ev.shape[1] for ev in eigenvectors_list]
        if len(set(dims)) > 1:
            # Dimension mismatch! This shouldn't happen with proper bucketing
            # but can occur if batch contains mixed sizes
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
                f"Image keys: {image_keys}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )

        eigenvectors = torch.stack(eigenvectors_list, dim=0)
        eigenvalues = torch.stack(eigenvalues_list, dim=0)
        return eigenvectors, eigenvalues

    def get_shape(self, image_key: str) -> Tuple[int, ...]:
        """Get the original shape for a sample (cached in memory)"""
        return self.shapes_cache[image_key]

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2)

        Samples whose eigenvalues are all (near) zero are returned unchanged,
        row by row, so CDC and fallback samples may share a batch.

        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d) or (B, C, H, W) - flattened to (B, -1) if it has more than 2 dims
            t: (B,) or scalar time
        Returns:
            result: Same shape as input x
        Note:
            Gradients flow through this function for backprop during training.
        """
        # Store original shape to restore later
        orig_shape = x.shape
        # Flatten everything after the batch dimension
        if x.dim() > 2:
            x = x.reshape(x.shape[0], -1)  # (B, C*H*W)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)
        if t.dim() == 0:
            t = t.expand(x.shape[0])
        t = t.view(-1, 1)

        # Early return for t=0 to avoid numerical errors
        if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)

        # Per-sample "CDC disabled" flag: all eigenvalues are zero
        # (this happens for buckets with too few samples)
        cdc_disabled = (eigenvalues.abs() <= 1e-8).all(dim=1, keepdim=True)  # (B, 1)
        if bool(cdc_disabled.all()):
            # Fallback to standard Gaussian noise (no CDC correction)
            return x.reshape(orig_shape)

        # Γ_b^(1/2) @ x using low-rank representation
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)

        # Σ_t @ x
        result = (1 - t) * x + t * gamma_sqrt_x
        # Disabled rows keep x as-is instead of being shrunk to (1-t) x
        result = torch.where(cdc_disabled, x.to(result.dtype), result)

        # Restore original shape
        return result.reshape(orig_shape)