mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
730 lines
27 KiB
Python
730 lines
27 KiB
Python
import logging
|
||
import torch
|
||
import numpy as np
|
||
import faiss # type: ignore
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
from safetensors.torch import save_file
|
||
from typing import List, Dict, Optional, Union, Tuple
|
||
from dataclasses import dataclass
|
||
|
||
# Module-level logger named after this module; used below for warnings about
# numerical fallbacks and for progress/size reporting.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class LatentSample:
    """
    One flattened latent plus the bookkeeping needed to trace it back.

    Attributes:
        latent: 1-D array of length d holding the flattened latent values.
        global_idx: Index of this sample in the overall dataset.
        shape: Shape the latent had before flattening, e.g. (C, H, W).
        metadata: Optional free-form extras (prompt, filename, etc.).
    """

    latent: np.ndarray
    global_idx: int
    shape: Tuple[int, ...]
    metadata: Optional[Dict] = None
|
||
|
||
|
||
class CarreDuChampComputer:
    """
    Core CDC-FM computation - agnostic to data source.

    Handles the math for one batch of same-size (already flattened) latents:
    builds a k-NN graph with FAISS, derives a per-point kernel bandwidth from
    it, and extracts the top ``d_cdc`` eigenpairs of the kernel-weighted local
    covariance (Γ_b) around every point.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda'
    ):
        """
        Args:
            k_neighbors: Number of nearest neighbors requested per point.
            k_bandwidth: Column of the k-NN distance matrix used as bandwidth.
            d_cdc: Number of eigenpairs kept per point.
            gamma: Global scale applied to the normalized eigenvalues.
            device: Preferred torch device; CPU is used when CUDA is unavailable.
        """
        self.k = k_neighbors
        self.k_bw = k_bandwidth
        self.d_cdc = d_cdc
        self.gamma = gamma
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build k-NN graph using FAISS.

        Args:
            latents_np: (N, d) numpy array of same-dimensional latents

        Returns:
            distances: (N, k_actual+1) Euclidean (non-squared) distances
                (k_actual may be less than k if N is small)
            indices: (N, k_actual+1) neighbor indices
        """
        N, d = latents_np.shape

        # Clamp k to available neighbors (can't have more neighbors than samples)
        k_actual = min(self.k, N - 1)

        # FAISS expects C-contiguous float32 input
        latents_np = np.ascontiguousarray(latents_np, dtype=np.float32)

        # Build FAISS index
        index = faiss.IndexFlatL2(d)

        # Only move to GPU when the installed FAISS build actually ships GPU
        # support; a CPU-only build on a CUDA machine lacks these symbols.
        if torch.cuda.is_available() and hasattr(faiss, "StandardGpuResources"):
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)

        index.add(latents_np)  # type: ignore
        sq_distances, indices = index.search(latents_np, k_actual + 1)  # type: ignore

        # IndexFlatL2 reports *squared* L2 distances. Callers square the
        # distances themselves and use them directly as bandwidths, so convert
        # to true distances here. The clamp removes tiny negative values that
        # can come out of the search due to floating-point round-off.
        distances = np.sqrt(np.maximum(sq_distances, 0.0))

        return distances, indices

    @torch.no_grad()
    def compute_gamma_b_single(
        self,
        point_idx: int,
        latents_np: np.ndarray,
        distances: np.ndarray,
        indices: np.ndarray,
        epsilon: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Γ_b for a single point.

        Args:
            point_idx: Index of point to process
            latents_np: (N, d) all latents in this batch
            distances: (N, k+1) precomputed distances
            indices: (N, k+1) precomputed neighbor indices
            epsilon: (N,) bandwidth per point

        Returns:
            eigenvectors: (d_cdc, d) as half precision CPU tensor
            eigenvalues: (d_cdc,) as half precision CPU tensor
        """
        d = latents_np.shape[1]

        # Get neighbors (column 0 is skipped as the self-match)
        neighbor_idx = indices[point_idx, 1:]  # (k,)
        neighbor_points = latents_np[neighbor_idx]  # (k, d)

        # Clamp distances to prevent overflow when squaring
        MAX_DISTANCE = 1e10
        neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE)
        neighbor_dists_sq = neighbor_dists ** 2  # (k,)

        # Gaussian kernel weights exp(-dist^2 / (eps_i * eps_j)) with numerical guards
        eps_i = max(epsilon[point_idx], 1e-10)  # Prevent division by zero
        eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10)
        denom = np.maximum(eps_i * eps_neighbors, 1e-20)  # Additional guard

        exp_arg = np.clip(-neighbor_dists_sq / denom, -50, 0)  # Prevent exp overflow/underflow
        weights = np.exp(exp_arg)

        # Normalize weights, handle edge case of all zeros
        weight_sum = weights.sum()
        if weight_sum < 1e-20 or not np.isfinite(weight_sum):
            # Fallback to uniform weights
            weights = np.ones_like(weights) / len(weights)
        else:
            weights = weights / weight_sum

        # Weighted local mean, then center and weight the neighbors for SVD
        m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
        centered = neighbor_points - m_star
        weighted_centered = np.sqrt(weights)[:, None] * centered  # (k, d)

        # Validate input is finite before SVD
        if not np.all(np.isfinite(weighted_centered)):
            logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback")
            # Fallback: use uniform weights and simple centering
            weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points)
            m_star = np.mean(neighbor_points, axis=0)
            centered = neighbor_points - m_star
            weighted_centered = np.sqrt(weights_uniform)[:, None] * centered

        # Run the SVD on self.device (GPU when available)
        weighted_centered_torch = torch.from_numpy(weighted_centered).to(
            self.device, dtype=torch.float32
        )

        try:
            _, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False)
        except RuntimeError as e:
            logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}")
            try:
                _, S_np, Vh_np = np.linalg.svd(weighted_centered, full_matrices=False)
            except np.linalg.LinAlgError as e2:
                logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix")
                # Return zero eigenvalues/vectors as fallback
                return (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16)
                )
            # Cast to float32 so later concatenation with float32 padding
            # never mixes dtypes (numpy may have computed in float64).
            S = torch.from_numpy(S_np).to(self.device, dtype=torch.float32)
            Vh = torch.from_numpy(Vh_np).to(self.device, dtype=torch.float32)

        # Eigenvalues of Γ_b are the squared singular values
        eigenvalues_full = S ** 2

        # Keep top d_cdc
        if len(eigenvalues_full) >= self.d_cdc:
            top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
            top_eigenvectors = Vh[top_idx, :]  # (d_cdc, d)
        else:
            # Fewer singular values than d_cdc: zero-pad up to d_cdc
            pad_size = self.d_cdc - len(eigenvalues_full)
            top_eigenvalues = torch.cat([
                eigenvalues_full,
                torch.zeros(pad_size, device=self.device, dtype=eigenvalues_full.dtype)
            ])
            top_eigenvectors = torch.cat([
                Vh,
                torch.zeros(pad_size, d, device=self.device, dtype=Vh.dtype)
            ])

        # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
        # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
        # Then apply gamma: γc_i Γ̂(x^(i))
        #
        # Our implementation:
        # 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor
        # 2. Apply gamma hyperparameter - aligns with paper's γ global scaling
        # 3. Clamp for numerical stability
        #
        # Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents)
        # Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound

        # Step 1: Normalize by the maximum eigenvalue to get relative scales
        # (entry 0 is the largest: topk and SVD both return descending order)
        max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0

        if max_eigenval > 1e-10:
            # Scale so max eigenvalue = 1.0, preserving relative ratios
            top_eigenvalues = top_eigenvalues / max_eigenval

        # Step 2: Apply gamma and clamp to safe range
        top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0)

        # Convert to fp16 for storage - the clamp above bounds the eigenvalues
        # to [1e-3, gamma], which fp16 represents for the default gamma=1.0
        eigenvectors_fp16 = top_eigenvectors.cpu().half()
        eigenvalues_fp16 = top_eigenvalues.cpu().half()

        # Cleanup
        del weighted_centered_torch, S, Vh, top_eigenvectors, top_eigenvalues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return eigenvectors_fp16, eigenvalues_fp16

    def compute_for_batch(
        self,
        latents_np: np.ndarray,
        global_indices: List[int]
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute Γ_b for all points in a batch of same-size latents.

        Args:
            latents_np: (N, d) numpy array
            global_indices: List of global dataset indices for each latent

        Returns:
            Dict mapping global_idx -> (eigenvectors, eigenvalues), where
            eigenvectors is a (d_cdc, d) and eigenvalues a (d_cdc,) float16
            torch tensor. Batches with fewer than 5 samples get all-zero
            tensors (no CDC correction).

        Raises:
            ValueError: If len(global_indices) does not match N.
        """
        N, d = latents_np.shape

        # Validate inputs
        if len(global_indices) != N:
            raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices")

        print(f"Computing CDC for batch: {N} samples, dim={d}")

        # Handle small sample cases - require minimum samples for meaningful k-NN
        MIN_SAMPLES_FOR_CDC = 5  # Need at least 5 samples for reasonable geometry estimation

        if N < MIN_SAMPLES_FOR_CDC:
            print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)")
            # Zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x).
            # Returned as torch tensors so both code paths honour the declared return type.
            return {
                global_idx: (
                    torch.zeros(self.d_cdc, d, dtype=torch.float16),
                    torch.zeros(self.d_cdc, dtype=torch.float16),
                )
                for global_idx in global_indices
            }

        # Step 1: Build k-NN graph
        print(" Building k-NN graph...")
        distances, indices = self.compute_knn_graph(latents_np)

        # Step 2: Compute bandwidth
        # Use min to handle case where k_bw >= actual neighbors returned
        k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
        epsilon = distances[:, k_bw_actual]

        # Step 3: Compute Γ_b for each point
        results = {}
        print(" Computing Γ_b for each point...")
        for local_idx in tqdm(range(N), desc=" Processing", leave=False):
            results[global_indices[local_idx]] = self.compute_gamma_b_single(
                local_idx, latents_np, distances, indices, epsilon
            )

        return results
|
||
|
||
|
||
class LatentBatcher:
|
||
"""
|
||
Collects variable-size latents and batches them by size
|
||
"""
|
||
|
||
def __init__(self, size_tolerance: float = 0.0):
|
||
"""
|
||
Args:
|
||
size_tolerance: If > 0, group latents within tolerance % of size
|
||
If 0, only exact size matches are batched
|
||
"""
|
||
self.size_tolerance = size_tolerance
|
||
self.samples: List[LatentSample] = []
|
||
|
||
def add_sample(self, sample: LatentSample):
|
||
"""Add a single latent sample"""
|
||
self.samples.append(sample)
|
||
|
||
def add_latent(
|
||
self,
|
||
latent: Union[np.ndarray, torch.Tensor],
|
||
global_idx: int,
|
||
shape: Optional[Tuple[int, ...]] = None,
|
||
metadata: Optional[Dict] = None
|
||
):
|
||
"""
|
||
Add a latent vector with automatic shape tracking
|
||
|
||
Args:
|
||
latent: Latent vector (any shape, will be flattened)
|
||
global_idx: Global index in dataset
|
||
shape: Original shape (if None, uses latent.shape)
|
||
metadata: Optional metadata dict
|
||
"""
|
||
# Convert to numpy and flatten
|
||
if isinstance(latent, torch.Tensor):
|
||
latent_np = latent.cpu().numpy()
|
||
else:
|
||
latent_np = latent
|
||
|
||
original_shape = shape if shape is not None else latent_np.shape
|
||
latent_flat = latent_np.flatten()
|
||
|
||
sample = LatentSample(
|
||
latent=latent_flat,
|
||
global_idx=global_idx,
|
||
shape=original_shape,
|
||
metadata=metadata
|
||
)
|
||
|
||
self.add_sample(sample)
|
||
|
||
def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
|
||
"""
|
||
Group samples by exact shape to avoid resizing distortion.
|
||
|
||
Each bucket contains only samples with identical latent dimensions.
|
||
Buckets with fewer than k_neighbors samples will be skipped during CDC
|
||
computation and fall back to standard Gaussian noise.
|
||
|
||
Returns:
|
||
Dict mapping exact_shape -> list of samples with that shape
|
||
"""
|
||
batches = {}
|
||
|
||
for sample in self.samples:
|
||
shape_key = sample.shape
|
||
|
||
# Group by exact shape only - no aspect ratio grouping or resizing
|
||
if shape_key not in batches:
|
||
batches[shape_key] = []
|
||
|
||
batches[shape_key].append(sample)
|
||
|
||
return batches
|
||
|
||
def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str:
|
||
"""
|
||
Get aspect ratio category for grouping.
|
||
Groups images by aspect ratio bins to ensure sufficient samples.
|
||
|
||
For shape (C, H, W), computes aspect ratio H/W and bins it.
|
||
"""
|
||
if len(shape) < 3:
|
||
return "unknown"
|
||
|
||
# Extract spatial dimensions (H, W)
|
||
h, w = shape[-2], shape[-1]
|
||
aspect_ratio = h / w
|
||
|
||
# Define aspect ratio bins (±15% tolerance)
|
||
# Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16)
|
||
bins = [
|
||
(0.5, 0.65, "9:16"), # Portrait tall
|
||
(0.65, 0.85, "3:4"), # Portrait
|
||
(0.85, 1.15, "1:1"), # Square
|
||
(1.15, 1.50, "4:3"), # Landscape
|
||
(1.50, 2.0, "16:9"), # Landscape wide
|
||
(2.0, 3.0, "21:9"), # Ultra wide
|
||
]
|
||
|
||
for min_ratio, max_ratio, label in bins:
|
||
if min_ratio <= aspect_ratio < max_ratio:
|
||
return label
|
||
|
||
# Fallback for extreme ratios
|
||
if aspect_ratio < 0.5:
|
||
return "ultra_tall"
|
||
else:
|
||
return "ultra_wide"
|
||
|
||
def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
|
||
"""Check if two shapes are within tolerance"""
|
||
if len(shape1) != len(shape2):
|
||
return False
|
||
|
||
size1 = np.prod(shape1)
|
||
size2 = np.prod(shape2)
|
||
|
||
ratio = abs(size1 - size2) / max(size1, size2)
|
||
return ratio <= self.size_tolerance
|
||
|
||
def __len__(self):
|
||
return len(self.samples)
|
||
|
||
|
||
class CDCPreprocessor:
    """
    High-level CDC preprocessing coordinator.

    Queues variable-size latents in a LatentBatcher, runs
    CarreDuChampComputer on each exact-shape bucket, and writes the resulting
    eigenpairs to a safetensors file keyed by each sample's ``image_key``.
    """

    def __init__(
        self,
        k_neighbors: int = 256,
        k_bandwidth: int = 8,
        d_cdc: int = 8,
        gamma: float = 1.0,
        device: str = 'cuda',
        size_tolerance: float = 0.0,
        debug: bool = False
    ):
        """
        Args:
            k_neighbors, k_bandwidth, d_cdc, gamma, device: Forwarded to
                CarreDuChampComputer.
            size_tolerance: Forwarded to LatentBatcher.
            debug: If True, print verbose per-bucket progress instead of
                showing a progress bar.
        """
        self.computer = CarreDuChampComputer(
            k_neighbors=k_neighbors,
            k_bandwidth=k_bandwidth,
            d_cdc=d_cdc,
            gamma=gamma,
            device=device
        )
        self.batcher = LatentBatcher(size_tolerance=size_tolerance)
        self.debug = debug

    def add_latent(
        self,
        latent: Union[np.ndarray, torch.Tensor],
        global_idx: int,
        shape: Optional[Tuple[int, ...]] = None,
        metadata: Optional[Dict] = None
    ):
        """
        Add a single latent to the preprocessing queue.

        Args:
            latent: Latent vector (will be flattened)
            global_idx: Global dataset index
            shape: Original shape (C, H, W)
            metadata: Optional metadata; ``compute_all`` requires it to
                contain an ``'image_key'`` entry.
        """
        self.batcher.add_latent(latent, global_idx, shape, metadata)

    def compute_all(self, save_path: Union[str, Path]) -> Path:
        """
        Compute Γ_b for all added latents and save to safetensors.

        Buckets with fewer than ``k_neighbors`` samples are not processed;
        their samples are stored with all-zero eigenvectors/eigenvalues.

        Args:
            save_path: Path to save the results

        Returns:
            Path to saved file

        Raises:
            TypeError / KeyError: If a queued sample has no metadata or its
                metadata lacks ``'image_key'``. Raised before any computation.
        """
        # Resolve every image_key up front: the same lookup used to happen
        # only at save time, i.e. after all the expensive work was done.
        image_keys = [sample.metadata['image_key'] for sample in self.batcher.samples]

        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Get batches by exact size (no resizing)
        batches = self.batcher.get_batches()

        # Count samples that will get CDC vs fallback
        k_neighbors = self.computer.k
        total = len(self.batcher)
        samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
        samples_fallback = total - samples_with_cdc

        if self.debug:
            # Guard the percentages against an empty queue
            pct_cdc = samples_with_cdc / total * 100 if total else 0.0
            pct_fallback = samples_fallback / total * 100 if total else 0.0
            print(f"\nProcessing {total} samples in {len(batches)} exact size buckets")
            print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{total} ({pct_cdc:.1f}%)")
            print(f" Samples using Gaussian fallback: {samples_fallback}/{total} ({pct_fallback:.1f}%)")
        else:
            logger.info(f"Processing {total} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")

        # Storage for results
        all_results = {}

        # Progress bar only in non-debug mode (debug mode prints per bucket instead)
        if self.debug:
            bucket_iter = batches.items()
        else:
            bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket")

        for shape, samples in bucket_iter:
            num_samples = len(samples)

            if self.debug:
                print(f"\n{'='*60}")
                print(f"Bucket: {shape} ({num_samples} samples)")
                print(f"{'='*60}")

            # Check if bucket has enough samples for k-NN
            if num_samples < k_neighbors:
                if self.debug:
                    print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
                    print(" → These samples will use standard Gaussian noise (no CDC)")

                # Flattened dimension; product of the shape works for any
                # rank, not only (C, H, W).
                d = int(np.prod(shape))

                # Store zero eigenvectors/eigenvalues (Gaussian fallback)
                for sample in samples:
                    all_results[sample.global_idx] = (
                        torch.zeros(self.computer.d_cdc, d, dtype=torch.float16),
                        torch.zeros(self.computer.d_cdc, dtype=torch.float16),
                    )

                continue

            # Collect latents (no resizing needed - all same shape)
            global_indices = [sample.global_idx for sample in samples]
            latents_np = np.stack([sample.latent for sample in samples], axis=0)  # (N, d)

            # Compute CDC for this batch
            if self.debug:
                print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
            batch_results = self.computer.compute_for_batch(latents_np, global_indices)

            # No resizing needed - eigenvectors are already correct size
            if self.debug:
                print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)")

            # Merge into overall results
            all_results.update(batch_results)

        # Save to safetensors
        if self.debug:
            print(f"\n{'='*60}")
            print("Saving results...")
            print(f"{'='*60}")

        tensors_dict = {
            'metadata/num_samples': torch.tensor([len(all_results)]),
            'metadata/k_neighbors': torch.tensor([self.computer.k]),
            'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
            'metadata/gamma': torch.tensor([self.computer.gamma]),
        }

        # Add shape information and CDC results for each sample,
        # using image_key as the identifier
        for sample, image_key in zip(self.batcher.samples, image_keys):
            tensors_dict[f'shapes/{image_key}'] = torch.tensor(sample.shape)

            # Get CDC results for this sample
            if sample.global_idx in all_results:
                eigvecs, eigvals = all_results[sample.global_idx]

                # Accept numpy arrays as well as tensors
                if isinstance(eigvecs, np.ndarray):
                    eigvecs = torch.from_numpy(eigvecs)
                if isinstance(eigvals, np.ndarray):
                    eigvals = torch.from_numpy(eigvals)

                tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
                tensors_dict[f'eigenvalues/{image_key}'] = eigvals

        save_file(tensors_dict, save_path)

        file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
        logger.info(f"Saved to {save_path}")
        logger.info(f"File size: {file_size_gb:.2f} GB")

        return save_path
|
||
|
||
|
||
class GammaBDataset:
    """
    Loader for Γ_b eigenpairs during training.

    Reads the safetensors file written by CDCPreprocessor. Shapes are cached
    in memory at construction; eigenvectors/eigenvalues are read from the file
    on demand per batch. Handles variable-size latents.
    """

    def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
        """
        Args:
            gamma_b_path: Path to the safetensors file with CDC results.
            device: Preferred torch device; CPU is used when CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.gamma_b_path = Path(gamma_b_path)

        # Load metadata
        logger.info(f"Loading Γ_b from {gamma_b_path}...")
        from safetensors import safe_open

        shape_prefix = 'shapes/'

        with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
            self.num_samples = int(f.get_tensor('metadata/num_samples').item())
            self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())

            # Cache all shapes in memory to avoid repeated I/O during training
            self.shapes_cache = {}
            for key in f.keys():
                if not key.startswith(shape_prefix):
                    continue
                # Strip only the leading prefix: an image_key may itself
                # contain "shapes/" (e.g. a file path), which a global
                # str.replace would corrupt.
                image_key = key[len(shape_prefix):]
                self.shapes_cache[image_key] = tuple(f.get_tensor(key).numpy().tolist())

        logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
        logger.info(f"Cached {len(self.shapes_cache)} shapes in memory")

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        image_keys: Union[List[str], List],
        device: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get Γ_b^(1/2) components for a batch of image_keys.

        Args:
            image_keys: List of image_key strings
            device: Device to load to (defaults to self.device)

        Returns:
            eigenvectors: (B, d_cdc, d) float32
            eigenvalues: (B, d_cdc) float32

        Raises:
            RuntimeError: If the requested samples do not all share the same
                eigenvector dimension d.
        """
        if device is None:
            device = self.device

        # Load from safetensors
        from safetensors import safe_open

        eigenvectors_list = []
        eigenvalues_list = []

        with safe_open(str(self.gamma_b_path), framework="pt", device=str(device)) as f:
            for image_key in image_keys:
                eigenvectors_list.append(f.get_tensor(f'eigenvectors/{image_key}').float())
                eigenvalues_list.append(f.get_tensor(f'eigenvalues/{image_key}').float())

        # Stacking requires every sample in the batch to share the same d
        dims = {ev.shape[1] for ev in eigenvectors_list}
        if len(dims) > 1:
            # Dimension mismatch! This shouldn't happen with proper bucketing
            # but can occur if batch contains mixed sizes
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {dims}. "
                f"Image keys: {image_keys}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )

        eigenvectors = torch.stack(eigenvectors_list, dim=0)
        eigenvalues = torch.stack(eigenvalues_list, dim=0)

        return eigenvectors, eigenvalues

    def get_shape(self, image_key: str) -> Tuple[int, ...]:
        """Get the original shape for a sample (cached in memory)"""
        return self.shapes_cache[image_key]

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2).

        Samples whose eigenvalues are all (near) zero have CDC disabled and
        are returned unchanged, independently of the other samples in the
        batch.

        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d), or (B, ...) with more dims - flattened to (B, d) internally
            t: (B,) or scalar time

        Returns:
            result: Same shape as input x

        Note:
            This method is not wrapped in no_grad, so autograd can track
            through it.
        """
        # Store original shape to restore later
        orig_shape = x.shape

        # Flatten everything after the batch dimension
        if x.dim() > 2:
            x = x.reshape(x.shape[0], -1)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        if t.dim() == 0:
            t = t.expand(x.shape[0])

        t = t.view(-1, 1)

        # Early return for t=0 to avoid numerical errors
        if torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)

        # Per-sample "CDC disabled" mask (all eigenvalues ~ zero). A sample
        # with zero eigenvectors would otherwise be silently scaled by (1 - t)
        # when it shares a batch with samples that do have CDC data.
        eigenvalues = eigenvalues.to(x.device)
        cdc_disabled = (eigenvalues.abs() <= 1e-8).all(dim=1, keepdim=True)  # (B, 1)
        if bool(cdc_disabled.all()):
            # Fallback to standard Gaussian noise (no CDC correction)
            return x.reshape(orig_shape)

        # Align eigenpairs with x so einsum does not fail on dtype/device mismatch
        eigenvectors = eigenvectors.to(device=x.device, dtype=x.dtype)
        eigenvalues = eigenvalues.to(dtype=x.dtype)

        # Γ_b^(1/2) @ x using low-rank representation
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)

        # Σ_t @ x
        result = (1 - t) * x + t * gamma_sqrt_x

        # Keep CDC-disabled samples untouched
        result = torch.where(cdc_disabled, x.to(result.dtype), result)

        # Restore original shape
        return result.reshape(orig_shape)
|