Add adaptive k_neighbors support for CDC-FM

- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size
- Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16)
- Fixed mode (default): Skip buckets with < k_neighbors samples
- Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size
- Update CDCPreprocessor to support adaptive k per bucket
- Add metadata tracking for adaptive_k and min_bucket_size
- Add comprehensive pytest tests for adaptive k behavior

This allows CDC-FM to work effectively with multi-resolution bucketing where
bucket sizes may vary widely. Users can choose between strict paper methodology
(fixed k) or pragmatic approach (adaptive k).
This commit is contained in:
rockerBOO
2025-10-09 23:16:44 -04:00
parent f450443fe4
commit 7ca799ca26
5 changed files with 317 additions and 20 deletions

View File

@@ -425,7 +425,9 @@ class CDCPreprocessor:
gamma: float = 1.0,
device: str = 'cuda',
size_tolerance: float = 0.0,
debug: bool = False
debug: bool = False,
adaptive_k: bool = False,
min_bucket_size: int = 16
):
self.computer = CarreDuChampComputer(
k_neighbors=k_neighbors,
@@ -436,6 +438,8 @@ class CDCPreprocessor:
)
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
self.debug = debug
self.adaptive_k = adaptive_k
self.min_bucket_size = min_bucket_size
def add_latent(
self,
@@ -473,15 +477,23 @@ class CDCPreprocessor:
# Count samples that will get CDC vs fallback
k_neighbors = self.computer.k
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
if self.adaptive_k:
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
else:
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
samples_fallback = len(self.batcher) - samples_with_cdc
if self.debug:
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
if self.adaptive_k:
print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
else:
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
mode = "adaptive" if self.adaptive_k else "fixed"
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
# Storage for results
all_results = {}
@@ -497,22 +509,46 @@ class CDCPreprocessor:
print(f"Bucket: {shape} ({num_samples} samples)")
print(f"{'='*60}")
# Check if bucket has enough samples for k-NN
if num_samples < k_neighbors:
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Determine effective k for this bucket
if self.adaptive_k:
# Adaptive mode: skip if below minimum, otherwise use best available k
if num_samples < min_threshold:
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
continue
continue
# Use adaptive k for this bucket
k_effective = min(k_neighbors, num_samples - 1)
else:
# Fixed mode: skip if below k_neighbors
if num_samples < k_neighbors:
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
continue
k_effective = k_neighbors
# Collect latents (no resizing needed - all same shape)
latents_list = []
@@ -524,10 +560,18 @@ class CDCPreprocessor:
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
# Compute CDC for this batch
# Compute CDC for this batch with effective k
if self.debug:
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
if self.adaptive_k and k_effective < k_neighbors:
print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
else:
print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
# Temporarily override k for this bucket
original_k = self.computer.k
self.computer.k = k_effective
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
self.computer.k = original_k
# No resizing needed - eigenvectors are already correct size
if self.debug:

View File

@@ -2707,6 +2707,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
force_recache: bool = False,
accelerator: Optional["Accelerator"] = None,
debug: bool = False,
adaptive_k: bool = False,
min_bucket_size: int = 16,
) -> str:
"""
Cache CDC Γ_b matrices for all latents in the dataset
@@ -2751,7 +2753,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
from library.cdc_fm import CDCPreprocessor
preprocessor = CDCPreprocessor(
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
)
# Get caching strategy for loading latents