mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Add adaptive k_neighbors support for CDC-FM
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size - Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16) - Fixed mode (default): Skip buckets with < k_neighbors samples - Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size - Update CDCPreprocessor to support adaptive k per bucket - Add metadata tracking for adaptive_k and min_bucket_size - Add comprehensive pytest tests for adaptive k behavior This allows CDC-FM to work effectively with multi-resolution bucketing where bucket sizes may vary widely. Users can choose between strict paper methodology (fixed k) or pragmatic approach (adaptive k).
This commit is contained in:
@@ -467,6 +467,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
|
||||
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
|
||||
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
|
||||
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
|
||||
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
@@ -593,6 +595,23 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Enable verbose CDC debug output showing bucket details"
|
||||
" / CDCの詳細デバッグ出力を有効化(バケット詳細表示)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_adaptive_k",
|
||||
action="store_true",
|
||||
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
|
||||
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
|
||||
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
|
||||
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_min_bucket_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
|
||||
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
|
||||
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスノイズを使用。"
|
||||
"--cdc_adaptive_k有効時のみ関連(デフォルト: 16)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -425,7 +425,9 @@ class CDCPreprocessor:
|
||||
gamma: float = 1.0,
|
||||
device: str = 'cuda',
|
||||
size_tolerance: float = 0.0,
|
||||
debug: bool = False
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16
|
||||
):
|
||||
self.computer = CarreDuChampComputer(
|
||||
k_neighbors=k_neighbors,
|
||||
@@ -436,6 +438,8 @@ class CDCPreprocessor:
|
||||
)
|
||||
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
|
||||
self.debug = debug
|
||||
self.adaptive_k = adaptive_k
|
||||
self.min_bucket_size = min_bucket_size
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
@@ -473,15 +477,23 @@ class CDCPreprocessor:
|
||||
|
||||
# Count samples that will get CDC vs fallback
|
||||
k_neighbors = self.computer.k
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
|
||||
|
||||
if self.adaptive_k:
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
|
||||
else:
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
samples_fallback = len(self.batcher) - samples_with_cdc
|
||||
|
||||
if self.debug:
|
||||
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
|
||||
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
if self.adaptive_k:
|
||||
print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
|
||||
print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
|
||||
else:
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
mode = "adaptive" if self.adaptive_k else "fixed"
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
|
||||
# Storage for results
|
||||
all_results = {}
|
||||
@@ -497,22 +509,46 @@ class CDCPreprocessor:
|
||||
print(f"Bucket: {shape} ({num_samples} samples)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if bucket has enough samples for k-NN
|
||||
if num_samples < k_neighbors:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
# Determine effective k for this bucket
|
||||
if self.adaptive_k:
|
||||
# Adaptive mode: skip if below minimum, otherwise use best available k
|
||||
if num_samples < min_threshold:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
continue
|
||||
|
||||
# Use adaptive k for this bucket
|
||||
k_effective = min(k_neighbors, num_samples - 1)
|
||||
else:
|
||||
# Fixed mode: skip if below k_neighbors
|
||||
if num_samples < k_neighbors:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
|
||||
k_effective = k_neighbors
|
||||
|
||||
# Collect latents (no resizing needed - all same shape)
|
||||
latents_list = []
|
||||
@@ -524,10 +560,18 @@ class CDCPreprocessor:
|
||||
|
||||
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
|
||||
|
||||
# Compute CDC for this batch
|
||||
# Compute CDC for this batch with effective k
|
||||
if self.debug:
|
||||
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
if self.adaptive_k and k_effective < k_neighbors:
|
||||
print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
|
||||
else:
|
||||
print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
|
||||
# Temporarily override k for this bucket
|
||||
original_k = self.computer.k
|
||||
self.computer.k = k_effective
|
||||
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
|
||||
self.computer.k = original_k
|
||||
|
||||
# No resizing needed - eigenvectors are already correct size
|
||||
if self.debug:
|
||||
|
||||
@@ -2707,6 +2707,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
force_recache: bool = False,
|
||||
accelerator: Optional["Accelerator"] = None,
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16,
|
||||
) -> str:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
@@ -2751,7 +2753,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
|
||||
)
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
|
||||
230
tests/library/test_cdc_adaptive_k.py
Normal file
230
tests/library/test_cdc_adaptive_k.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Test adaptive k_neighbors functionality in CDC-FM.
|
||||
|
||||
Verifies that adaptive k properly adjusts based on bucket sizes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestAdaptiveK:
|
||||
"""Test adaptive k_neighbors behavior"""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_cache_path(self, tmp_path):
|
||||
"""Create temporary cache path"""
|
||||
return tmp_path / "adaptive_k_test.safetensors"
|
||||
|
||||
def test_fixed_k_skips_small_buckets(self, temp_cache_path):
|
||||
"""
|
||||
Test that fixed k mode skips buckets with < k_neighbors samples.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=False # Fixed mode
|
||||
)
|
||||
|
||||
# Add 10 samples (< k=32, should be skipped)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify zeros (Gaussian fallback)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should be all zeros (fallback)
|
||||
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
|
||||
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
|
||||
"""
|
||||
Test that adaptive k mode uses k=bucket_size-1 for small buckets.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=8
|
||||
)
|
||||
|
||||
# Add 20 samples (< k=32, should use k=19)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(20):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify non-zero (CDC computed)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should NOT be all zeros (CDC was computed)
|
||||
assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
|
||||
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
|
||||
"""
|
||||
Test that adaptive k mode skips buckets below min_bucket_size.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=16
|
||||
)
|
||||
|
||||
# Add 10 samples (< min_bucket_size=16, should be skipped)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(10):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify zeros (skipped due to min_bucket_size)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should be all zeros (skipped)
|
||||
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
|
||||
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
|
||||
"""
|
||||
Test adaptive k with multiple buckets of different sizes.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=32,
|
||||
k_bandwidth=8,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=8
|
||||
)
|
||||
|
||||
# Bucket 1: 10 samples (adaptive k=9)
|
||||
for i in range(10):
|
||||
latent = torch.randn(4, 16, 16, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=(4, 16, 16),
|
||||
metadata={'image_key': f'small_{i}'}
|
||||
)
|
||||
|
||||
# Bucket 2: 40 samples (full k=32)
|
||||
for i in range(40):
|
||||
latent = torch.randn(4, 32, 32, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=100+i,
|
||||
shape=(4, 32, 32),
|
||||
metadata={'image_key': f'large_{i}'}
|
||||
)
|
||||
|
||||
# Bucket 3: 5 samples (< min=8, skipped)
|
||||
for i in range(5):
|
||||
latent = torch.randn(4, 8, 8, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=200+i,
|
||||
shape=(4, 8, 8),
|
||||
metadata={'image_key': f'tiny_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
|
||||
# Bucket 1: Should have CDC (non-zero)
|
||||
eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
|
||||
assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6)
|
||||
|
||||
# Bucket 2: Should have CDC (non-zero)
|
||||
eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
|
||||
assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6)
|
||||
|
||||
# Bucket 3: Should be skipped (zeros)
|
||||
eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
|
||||
assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6)
|
||||
assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6)
|
||||
|
||||
def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
|
||||
"""
|
||||
Test that adaptive k uses full k_neighbors when bucket is large enough.
|
||||
"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=16,
|
||||
k_bandwidth=4,
|
||||
d_cdc=4,
|
||||
gamma=1.0,
|
||||
device='cpu',
|
||||
debug=False,
|
||||
adaptive_k=True,
|
||||
min_bucket_size=8
|
||||
)
|
||||
|
||||
# Add 50 samples (> k=16, should use full k=16)
|
||||
shape = (4, 16, 16)
|
||||
for i in range(50):
|
||||
latent = torch.randn(*shape, dtype=torch.float32).numpy()
|
||||
preprocessor.add_latent(
|
||||
latent=latent,
|
||||
global_idx=i,
|
||||
shape=shape,
|
||||
metadata={'image_key': f'test_{i}'}
|
||||
)
|
||||
|
||||
preprocessor.compute_all(temp_cache_path)
|
||||
|
||||
# Load and verify CDC was computed
|
||||
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
|
||||
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
|
||||
|
||||
# Should have non-zero eigenvalues
|
||||
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
|
||||
# Eigenvalues should be positive
|
||||
assert (eigvals >= 0).all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -636,6 +636,8 @@ class NetworkTrainer:
|
||||
force_recache=args.force_recache_cdc,
|
||||
accelerator=accelerator,
|
||||
debug=getattr(args, 'cdc_debug', False),
|
||||
adaptive_k=getattr(args, 'cdc_adaptive_k', False),
|
||||
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
|
||||
)
|
||||
else:
|
||||
self.cdc_cache_path = None
|
||||
|
||||
Reference in New Issue
Block a user