Formatting cleanup

This commit is contained in:
rockerBOO
2025-10-09 18:26:25 -04:00
parent c8a4e99074
commit f128f5a645

View File

@@ -27,7 +27,7 @@ class CarreDuChampComputer:
Core CDC-FM computation - agnostic to data source
Just handles the math for a batch of same-size latents
"""
def __init__(
self,
k_neighbors: int = 256,
@@ -41,7 +41,7 @@ class CarreDuChampComputer:
self.d_cdc = d_cdc
self.gamma = gamma
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Build k-NN graph using FAISS
@@ -73,7 +73,7 @@ class CarreDuChampComputer:
distances, indices = index.search(latents_np, k_actual + 1) # type: ignore
return distances, indices
@torch.no_grad()
def compute_gamma_b_single(
self,
@@ -128,10 +128,10 @@ class CarreDuChampComputer:
weights = np.ones_like(weights) / len(weights)
else:
weights = weights / weight_sum
# Compute local mean
m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
# Center and weight for SVD
centered = neighbor_points - m_star
weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d)
@@ -166,10 +166,10 @@ class CarreDuChampComputer:
torch.zeros(self.d_cdc, d, dtype=torch.float16),
torch.zeros(self.d_cdc, dtype=torch.float16)
)
# Eigenvalues of Γ_b
eigenvalues_full = S ** 2
# Keep top d_cdc
if len(eigenvalues_full) >= self.d_cdc:
top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
@@ -188,7 +188,7 @@ class CarreDuChampComputer:
top_eigenvectors,
torch.zeros(pad_size, d, device=self.device)
])
# Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
# Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
# Then apply gamma: γc_i Γ̂(x^(i))
@@ -225,7 +225,7 @@ class CarreDuChampComputer:
torch.cuda.empty_cache()
return eigenvectors_fp16, eigenvalues_fp16
def compute_for_batch(
self,
latents_np: np.ndarray,
@@ -266,12 +266,12 @@ class CarreDuChampComputer:
# Step 1: Build k-NN graph
print(" Building k-NN graph...")
distances, indices = self.compute_knn_graph(latents_np)
# Step 2: Compute bandwidth
# Use min to handle case where k_bw >= actual neighbors returned
k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
epsilon = distances[:, k_bw_actual]
# Step 3: Compute Γ_b for each point
results = {}
print(" Computing Γ_b for each point...")
@@ -281,7 +281,7 @@ class CarreDuChampComputer:
local_idx, latents_np, distances, indices, epsilon
)
results[global_idx] = (eigvecs, eigvals)
return results
@@ -289,7 +289,7 @@ class LatentBatcher:
"""
Collects variable-size latents and batches them by size
"""
def __init__(self, size_tolerance: float = 0.0):
"""
Args:
@@ -298,11 +298,11 @@ class LatentBatcher:
"""
self.size_tolerance = size_tolerance
self.samples: List[LatentSample] = []
def add_sample(self, sample: LatentSample):
"""Add a single latent sample"""
self.samples.append(sample)
def add_latent(
self,
latent: Union[np.ndarray, torch.Tensor],
@@ -324,19 +324,19 @@ class LatentBatcher:
latent_np = latent.cpu().numpy()
else:
latent_np = latent
original_shape = shape if shape is not None else latent_np.shape
latent_flat = latent_np.flatten()
sample = LatentSample(
latent=latent_flat,
global_idx=global_idx,
shape=original_shape,
metadata=metadata
)
self.add_sample(sample)
def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
"""
Group samples by exact shape to avoid resizing distortion.
@@ -395,18 +395,18 @@ class LatentBatcher:
return "ultra_tall"
else:
return "ultra_wide"
def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
"""Check if two shapes are within tolerance"""
if len(shape1) != len(shape2):
return False
size1 = np.prod(shape1)
size2 = np.prod(shape2)
ratio = abs(size1 - size2) / max(size1, size2)
return ratio <= self.size_tolerance
def __len__(self):
return len(self.samples)
@@ -416,7 +416,7 @@ class CDCPreprocessor:
High-level CDC preprocessing coordinator
Handles variable-size latents by batching and delegating to CarreDuChampComputer
"""
def __init__(
self,
k_neighbors: int = 256,
@@ -436,7 +436,7 @@ class CDCPreprocessor:
)
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
self.debug = debug
def add_latent(
self,
latent: Union[np.ndarray, torch.Tensor],
@@ -454,7 +454,7 @@ class CDCPreprocessor:
metadata: Optional metadata
"""
self.batcher.add_latent(latent, global_idx, shape, metadata)
def compute_all(self, save_path: Union[str, Path]) -> Path:
"""
Compute Γ_b for all added latents and save to safetensors
@@ -467,7 +467,7 @@ class CDCPreprocessor:
"""
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
# Get batches by exact size (no resizing)
batches = self.batcher.get_batches()
@@ -541,14 +541,14 @@ class CDCPreprocessor:
print(f"\n{'='*60}")
print("Saving results...")
print(f"{'='*60}")
tensors_dict = {
'metadata/num_samples': torch.tensor([len(all_results)]),
'metadata/k_neighbors': torch.tensor([self.computer.k]),
'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
'metadata/gamma': torch.tensor([self.computer.gamma]),
}
# Add shape information and CDC results for each sample
# Use image_key as the identifier
for sample in self.batcher.samples:
@@ -567,7 +567,7 @@ class CDCPreprocessor:
tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
tensors_dict[f'eigenvalues/{image_key}'] = eigvals
save_file(tensors_dict, save_path)
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
@@ -582,11 +582,11 @@ class GammaBDataset:
Efficient loader for Γ_b matrices during training
Handles variable-size latents
"""
def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.gamma_b_path = Path(gamma_b_path)
# Load metadata
logger.info(f"Loading Γ_b from {gamma_b_path}...")
from safetensors import safe_open
@@ -608,7 +608,7 @@ class GammaBDataset:
logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
logger.info(f"Cached {len(self.shapes_cache)} shapes in memory")
@torch.no_grad()
def get_gamma_b_sqrt(
self,
@@ -661,11 +661,11 @@ class GammaBDataset:
eigenvalues = torch.stack(eigenvalues_list, dim=0)
return eigenvectors, eigenvalues
def get_shape(self, image_key: str) -> Tuple[int, ...]:
"""Get the original shape for a sample (cached in memory)"""
return self.shapes_cache[image_key]
def compute_sigma_t_x(
self,
eigenvectors: torch.Tensor,