diff --git a/library/cdc_fm.py b/library/cdc_fm.py index 81f9de29..8ecc773d 100644 --- a/library/cdc_fm.py +++ b/library/cdc_fm.py @@ -27,7 +27,7 @@ class CarreDuChampComputer: Core CDC-FM computation - agnostic to data source Just handles the math for a batch of same-size latents """ - + def __init__( self, k_neighbors: int = 256, @@ -41,7 +41,7 @@ class CarreDuChampComputer: self.d_cdc = d_cdc self.gamma = gamma self.device = torch.device(device if torch.cuda.is_available() else 'cpu') - + def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ Build k-NN graph using FAISS @@ -73,7 +73,7 @@ class CarreDuChampComputer: distances, indices = index.search(latents_np, k_actual + 1) # type: ignore return distances, indices - + @torch.no_grad() def compute_gamma_b_single( self, @@ -128,10 +128,10 @@ class CarreDuChampComputer: weights = np.ones_like(weights) / len(weights) else: weights = weights / weight_sum - + # Compute local mean m_star = np.sum(weights[:, None] * neighbor_points, axis=0) - + # Center and weight for SVD centered = neighbor_points - m_star weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d) @@ -166,10 +166,10 @@ class CarreDuChampComputer: torch.zeros(self.d_cdc, d, dtype=torch.float16), torch.zeros(self.d_cdc, dtype=torch.float16) ) - + # Eigenvalues of Γ_b eigenvalues_full = S ** 2 - + # Keep top d_cdc if len(eigenvalues_full) >= self.d_cdc: top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc) @@ -188,7 +188,7 @@ class CarreDuChampComputer: top_eigenvectors, torch.zeros(pad_size, d, device=self.device) ]) - + # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33) # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max) # Then apply gamma: γc_i Γ̂(x^(i)) @@ -225,7 +225,7 @@ class CarreDuChampComputer: torch.cuda.empty_cache() return eigenvectors_fp16, eigenvalues_fp16 - + def compute_for_batch( self, latents_np: np.ndarray, @@ -266,12 +266,12 @@ class CarreDuChampComputer: # Step 1: Build k-NN graph print(" Building k-NN graph...") distances, indices = self.compute_knn_graph(latents_np) - + # Step 2: Compute bandwidth # Use min to handle case where k_bw >= actual neighbors returned k_bw_actual = min(self.k_bw, distances.shape[1] - 1) epsilon = distances[:, k_bw_actual] - + # Step 3: Compute Γ_b for each point results = {} print(" Computing Γ_b for each point...") @@ -281,7 +281,7 @@ class CarreDuChampComputer: local_idx, latents_np, distances, indices, epsilon ) results[global_idx] = (eigvecs, eigvals) - + return results @@ -289,7 +289,7 @@ class LatentBatcher: """ Collects variable-size latents and batches them by size """ - + def __init__(self, size_tolerance: float = 0.0): """ Args: @@ -298,11 +298,11 @@ class LatentBatcher: """ self.size_tolerance = size_tolerance self.samples: List[LatentSample] = [] - + def add_sample(self, sample: LatentSample): """Add a single latent sample""" self.samples.append(sample) - + def add_latent( self, latent: Union[np.ndarray, torch.Tensor], @@ -324,19 +324,19 @@ class LatentBatcher: latent_np = latent.cpu().numpy() else: latent_np = latent - + original_shape = shape if shape is not None else latent_np.shape latent_flat = latent_np.flatten() - + sample = LatentSample( latent=latent_flat, global_idx=global_idx, shape=original_shape, metadata=metadata ) - + self.add_sample(sample) - + def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]: """ Group samples by exact shape to avoid resizing distortion. @@ -395,18 +395,18 @@ class LatentBatcher: return "ultra_tall" else: return "ultra_wide" - + def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool: """Check if two shapes are within tolerance""" if len(shape1) != len(shape2): return False - + size1 = np.prod(shape1) size2 = np.prod(shape2) - + ratio = abs(size1 - size2) / max(size1, size2) return ratio <= self.size_tolerance - + def __len__(self): return len(self.samples) @@ -416,7 +416,7 @@ class CDCPreprocessor: High-level CDC preprocessing coordinator Handles variable-size latents by batching and delegating to CarreDuChampComputer """ - + def __init__( self, k_neighbors: int = 256, @@ -436,7 +436,7 @@ class CDCPreprocessor: ) self.batcher = LatentBatcher(size_tolerance=size_tolerance) self.debug = debug - + def add_latent( self, latent: Union[np.ndarray, torch.Tensor], @@ -454,7 +454,7 @@ class CDCPreprocessor: metadata: Optional metadata """ self.batcher.add_latent(latent, global_idx, shape, metadata) - + def compute_all(self, save_path: Union[str, Path]) -> Path: """ Compute Γ_b for all added latents and save to safetensors @@ -467,7 +467,7 @@ class CDCPreprocessor: """ save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) - + # Get batches by exact size (no resizing) batches = self.batcher.get_batches() @@ -541,14 +541,14 @@ class CDCPreprocessor: print(f"\n{'='*60}") print("Saving results...") print(f"{'='*60}") - + tensors_dict = { 'metadata/num_samples': torch.tensor([len(all_results)]), 'metadata/k_neighbors': torch.tensor([self.computer.k]), 'metadata/d_cdc': torch.tensor([self.computer.d_cdc]), 'metadata/gamma': torch.tensor([self.computer.gamma]), } - + # Add shape information and CDC results for each sample # Use image_key as the identifier for sample in self.batcher.samples: @@ -567,7 +567,7 @@ class CDCPreprocessor: tensors_dict[f'eigenvectors/{image_key}'] = eigvecs tensors_dict[f'eigenvalues/{image_key}'] = eigvals - + save_file(tensors_dict, save_path) file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024 @@ -582,11 +582,11 @@ class GammaBDataset: Efficient loader for Γ_b matrices during training Handles variable-size latents """ - + def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'): self.device = torch.device(device if torch.cuda.is_available() else 'cpu') self.gamma_b_path = Path(gamma_b_path) - + # Load metadata logger.info(f"Loading Γ_b from {gamma_b_path}...") from safetensors import safe_open @@ -608,7 +608,7 @@ class GammaBDataset: logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})") logger.info(f"Cached {len(self.shapes_cache)} shapes in memory") - + @torch.no_grad() def get_gamma_b_sqrt( self, @@ -661,11 +661,11 @@ class GammaBDataset: eigenvalues = torch.stack(eigenvalues_list, dim=0) return eigenvectors, eigenvalues - + def get_shape(self, image_key: str) -> Tuple[int, ...]: """Get the original shape for a sample (cached in memory)""" return self.shapes_cache[image_key] - + def compute_sigma_t_x( self, eigenvectors: torch.Tensor,