mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Formatting cleanup
This commit is contained in:
@@ -27,7 +27,7 @@ class CarreDuChampComputer:
|
||||
Core CDC-FM computation - agnostic to data source
|
||||
Just handles the math for a batch of same-size latents
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_neighbors: int = 256,
|
||||
@@ -41,7 +41,7 @@ class CarreDuChampComputer:
|
||||
self.d_cdc = d_cdc
|
||||
self.gamma = gamma
|
||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Build k-NN graph using FAISS
|
||||
@@ -73,7 +73,7 @@ class CarreDuChampComputer:
|
||||
distances, indices = index.search(latents_np, k_actual + 1) # type: ignore
|
||||
|
||||
return distances, indices
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_gamma_b_single(
|
||||
self,
|
||||
@@ -128,10 +128,10 @@ class CarreDuChampComputer:
|
||||
weights = np.ones_like(weights) / len(weights)
|
||||
else:
|
||||
weights = weights / weight_sum
|
||||
|
||||
|
||||
# Compute local mean
|
||||
m_star = np.sum(weights[:, None] * neighbor_points, axis=0)
|
||||
|
||||
|
||||
# Center and weight for SVD
|
||||
centered = neighbor_points - m_star
|
||||
weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d)
|
||||
@@ -166,10 +166,10 @@ class CarreDuChampComputer:
|
||||
torch.zeros(self.d_cdc, d, dtype=torch.float16),
|
||||
torch.zeros(self.d_cdc, dtype=torch.float16)
|
||||
)
|
||||
|
||||
|
||||
# Eigenvalues of Γ_b
|
||||
eigenvalues_full = S ** 2
|
||||
|
||||
|
||||
# Keep top d_cdc
|
||||
if len(eigenvalues_full) >= self.d_cdc:
|
||||
top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc)
|
||||
@@ -188,7 +188,7 @@ class CarreDuChampComputer:
|
||||
top_eigenvectors,
|
||||
torch.zeros(pad_size, d, device=self.device)
|
||||
])
|
||||
|
||||
|
||||
# Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33)
|
||||
# Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max)
|
||||
# Then apply gamma: γc_i Γ̂(x^(i))
|
||||
@@ -225,7 +225,7 @@ class CarreDuChampComputer:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return eigenvectors_fp16, eigenvalues_fp16
|
||||
|
||||
|
||||
def compute_for_batch(
|
||||
self,
|
||||
latents_np: np.ndarray,
|
||||
@@ -266,12 +266,12 @@ class CarreDuChampComputer:
|
||||
# Step 1: Build k-NN graph
|
||||
print(" Building k-NN graph...")
|
||||
distances, indices = self.compute_knn_graph(latents_np)
|
||||
|
||||
|
||||
# Step 2: Compute bandwidth
|
||||
# Use min to handle case where k_bw >= actual neighbors returned
|
||||
k_bw_actual = min(self.k_bw, distances.shape[1] - 1)
|
||||
epsilon = distances[:, k_bw_actual]
|
||||
|
||||
|
||||
# Step 3: Compute Γ_b for each point
|
||||
results = {}
|
||||
print(" Computing Γ_b for each point...")
|
||||
@@ -281,7 +281,7 @@ class CarreDuChampComputer:
|
||||
local_idx, latents_np, distances, indices, epsilon
|
||||
)
|
||||
results[global_idx] = (eigvecs, eigvals)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -289,7 +289,7 @@ class LatentBatcher:
|
||||
"""
|
||||
Collects variable-size latents and batches them by size
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, size_tolerance: float = 0.0):
|
||||
"""
|
||||
Args:
|
||||
@@ -298,11 +298,11 @@ class LatentBatcher:
|
||||
"""
|
||||
self.size_tolerance = size_tolerance
|
||||
self.samples: List[LatentSample] = []
|
||||
|
||||
|
||||
def add_sample(self, sample: LatentSample):
    """Append one already-constructed latent sample to ``self.samples``."""
    collected = self.samples
    collected.append(sample)
|
||||
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
latent: Union[np.ndarray, torch.Tensor],
|
||||
@@ -324,19 +324,19 @@ class LatentBatcher:
|
||||
latent_np = latent.cpu().numpy()
|
||||
else:
|
||||
latent_np = latent
|
||||
|
||||
|
||||
original_shape = shape if shape is not None else latent_np.shape
|
||||
latent_flat = latent_np.flatten()
|
||||
|
||||
|
||||
sample = LatentSample(
|
||||
latent=latent_flat,
|
||||
global_idx=global_idx,
|
||||
shape=original_shape,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
|
||||
self.add_sample(sample)
|
||||
|
||||
|
||||
def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]:
|
||||
"""
|
||||
Group samples by exact shape to avoid resizing distortion.
|
||||
@@ -395,18 +395,18 @@ class LatentBatcher:
|
||||
return "ultra_tall"
|
||||
else:
|
||||
return "ultra_wide"
|
||||
|
||||
|
||||
def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool:
    """Check whether two shapes have element counts within tolerance.

    Shapes of different rank are never considered similar. Otherwise the
    relative difference of the total element counts,
    ``|n1 - n2| / max(n1, n2)``, is compared against ``self.size_tolerance``.

    Args:
        shape1: First shape.
        shape2: Second shape.

    Returns:
        True if both shapes have the same rank and their relative size
        difference is at most ``self.size_tolerance``; False otherwise.
    """
    if len(shape1) != len(shape2):
        return False

    # Convert to plain ints so the result is a builtin bool, not numpy.bool_.
    size1 = int(np.prod(shape1))
    size2 = int(np.prod(shape2))

    larger = max(size1, size2)
    # Two empty shapes have identical size; dividing would yield 0/0 (NaN),
    # which compares False against any tolerance.
    ratio = 0.0 if larger == 0 else abs(size1 - size2) / larger
    return bool(ratio <= self.size_tolerance)
|
||||
|
||||
|
||||
def __len__(self):
    """Return the number of samples collected so far."""
    sample_count = len(self.samples)
    return sample_count
|
||||
|
||||
@@ -416,7 +416,7 @@ class CDCPreprocessor:
|
||||
High-level CDC preprocessing coordinator
|
||||
Handles variable-size latents by batching and delegating to CarreDuChampComputer
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_neighbors: int = 256,
|
||||
@@ -436,7 +436,7 @@ class CDCPreprocessor:
|
||||
)
|
||||
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
|
||||
self.debug = debug
|
||||
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
latent: Union[np.ndarray, torch.Tensor],
|
||||
@@ -454,7 +454,7 @@ class CDCPreprocessor:
|
||||
metadata: Optional metadata
|
||||
"""
|
||||
self.batcher.add_latent(latent, global_idx, shape, metadata)
|
||||
|
||||
|
||||
def compute_all(self, save_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Compute Γ_b for all added latents and save to safetensors
|
||||
@@ -467,7 +467,7 @@ class CDCPreprocessor:
|
||||
"""
|
||||
save_path = Path(save_path)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Get batches by exact size (no resizing)
|
||||
batches = self.batcher.get_batches()
|
||||
|
||||
@@ -541,14 +541,14 @@ class CDCPreprocessor:
|
||||
print(f"\n{'='*60}")
|
||||
print("Saving results...")
|
||||
print(f"{'='*60}")
|
||||
|
||||
|
||||
tensors_dict = {
|
||||
'metadata/num_samples': torch.tensor([len(all_results)]),
|
||||
'metadata/k_neighbors': torch.tensor([self.computer.k]),
|
||||
'metadata/d_cdc': torch.tensor([self.computer.d_cdc]),
|
||||
'metadata/gamma': torch.tensor([self.computer.gamma]),
|
||||
}
|
||||
|
||||
|
||||
# Add shape information and CDC results for each sample
|
||||
# Use image_key as the identifier
|
||||
for sample in self.batcher.samples:
|
||||
@@ -567,7 +567,7 @@ class CDCPreprocessor:
|
||||
|
||||
tensors_dict[f'eigenvectors/{image_key}'] = eigvecs
|
||||
tensors_dict[f'eigenvalues/{image_key}'] = eigvals
|
||||
|
||||
|
||||
save_file(tensors_dict, save_path)
|
||||
|
||||
file_size_gb = save_path.stat().st_size / 1024 / 1024 / 1024
|
||||
@@ -582,11 +582,11 @@ class GammaBDataset:
|
||||
Efficient loader for Γ_b matrices during training
|
||||
Handles variable-size latents
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, gamma_b_path: Union[str, Path], device: str = 'cuda'):
|
||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||
self.gamma_b_path = Path(gamma_b_path)
|
||||
|
||||
|
||||
# Load metadata
|
||||
logger.info(f"Loading Γ_b from {gamma_b_path}...")
|
||||
from safetensors import safe_open
|
||||
@@ -608,7 +608,7 @@ class GammaBDataset:
|
||||
|
||||
logger.info(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
|
||||
logger.info(f"Cached {len(self.shapes_cache)} shapes in memory")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_gamma_b_sqrt(
|
||||
self,
|
||||
@@ -661,11 +661,11 @@ class GammaBDataset:
|
||||
eigenvalues = torch.stack(eigenvalues_list, dim=0)
|
||||
|
||||
return eigenvectors, eigenvalues
|
||||
|
||||
|
||||
def get_shape(self, image_key: str) -> Tuple[int, ...]:
    """Look up the original shape recorded for ``image_key``.

    The lookup is served from ``self.shapes_cache``; an unknown key raises
    ``KeyError``.
    """
    cached_shape = self.shapes_cache[image_key]
    return cached_shape
|
||||
|
||||
|
||||
def compute_sigma_t_x(
|
||||
self,
|
||||
eigenvectors: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user