mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Add adaptive k_neighbors support for CDC-FM
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size - Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16) - Fixed mode (default): Skip buckets with < k_neighbors samples - Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size - Update CDCPreprocessor to support adaptive k per bucket - Add metadata tracking for adaptive_k and min_bucket_size - Add comprehensive pytest tests for adaptive k behavior This allows CDC-FM to work effectively with multi-resolution bucketing where bucket sizes may vary widely. Users can choose between strict paper methodology (fixed k) or pragmatic approach (adaptive k).
This commit is contained in:
@@ -425,7 +425,9 @@ class CDCPreprocessor:
|
||||
gamma: float = 1.0,
|
||||
device: str = 'cuda',
|
||||
size_tolerance: float = 0.0,
|
||||
debug: bool = False
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16
|
||||
):
|
||||
self.computer = CarreDuChampComputer(
|
||||
k_neighbors=k_neighbors,
|
||||
@@ -436,6 +438,8 @@ class CDCPreprocessor:
|
||||
)
|
||||
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
|
||||
self.debug = debug
|
||||
self.adaptive_k = adaptive_k
|
||||
self.min_bucket_size = min_bucket_size
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
@@ -473,15 +477,23 @@ class CDCPreprocessor:
|
||||
|
||||
# Count samples that will get CDC vs fallback
|
||||
k_neighbors = self.computer.k
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
|
||||
|
||||
if self.adaptive_k:
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
|
||||
else:
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
samples_fallback = len(self.batcher) - samples_with_cdc
|
||||
|
||||
if self.debug:
|
||||
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
|
||||
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
if self.adaptive_k:
|
||||
print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
|
||||
print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
|
||||
else:
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
mode = "adaptive" if self.adaptive_k else "fixed"
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
|
||||
# Storage for results
|
||||
all_results = {}
|
||||
@@ -497,22 +509,46 @@ class CDCPreprocessor:
|
||||
print(f"Bucket: {shape} ({num_samples} samples)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if bucket has enough samples for k-NN
|
||||
if num_samples < k_neighbors:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
# Determine effective k for this bucket
|
||||
if self.adaptive_k:
|
||||
# Adaptive mode: skip if below minimum, otherwise use best available k
|
||||
if num_samples < min_threshold:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
continue
|
||||
|
||||
# Use adaptive k for this bucket
|
||||
k_effective = min(k_neighbors, num_samples - 1)
|
||||
else:
|
||||
# Fixed mode: skip if below k_neighbors
|
||||
if num_samples < k_neighbors:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
|
||||
k_effective = k_neighbors
|
||||
|
||||
# Collect latents (no resizing needed - all same shape)
|
||||
latents_list = []
|
||||
@@ -524,10 +560,18 @@ class CDCPreprocessor:
|
||||
|
||||
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
|
||||
|
||||
# Compute CDC for this batch
|
||||
# Compute CDC for this batch with effective k
|
||||
if self.debug:
|
||||
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
if self.adaptive_k and k_effective < k_neighbors:
|
||||
print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
|
||||
else:
|
||||
print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
|
||||
# Temporarily override k for this bucket
|
||||
original_k = self.computer.k
|
||||
self.computer.k = k_effective
|
||||
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
|
||||
self.computer.k = original_k
|
||||
|
||||
# No resizing needed - eigenvectors are already correct size
|
||||
if self.debug:
|
||||
|
||||
@@ -2707,6 +2707,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
force_recache: bool = False,
|
||||
accelerator: Optional["Accelerator"] = None,
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16,
|
||||
) -> str:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
@@ -2751,7 +2753,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
|
||||
)
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
|
||||
Reference in New Issue
Block a user