Compare commits

...

3 Commits

Author SHA1 Message Date
rockerBOO
8458a5696e Add graceful fallback when FAISS is not installed
- Make FAISS import optional with try/except
- CDCPreprocessor raises helpful ImportError if FAISS unavailable
- train_util.py catches ImportError and returns None
- train_network.py checks for None and warns user
- Training continues without CDC-FM if FAISS not installed
- Remove benchmark file (not needed in repo)

This allows users to run training without the FAISS dependency.
CDC-FM will be automatically disabled with a warning if FAISS is missing.
2025-10-09 23:50:07 -04:00
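
A minimal sketch of the fallback chain this message describes (the wrapper name maybe_run_cdc_preprocessing is hypothetical; the real code is in the library/cdc_fm.py and library/train_util.py diffs below):

import logging

logger = logging.getLogger(__name__)

try:
    import faiss  # noqa: F401  # optional dependency
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False

def maybe_run_cdc_preprocessing(run_preprocessing):
    # Mirrors the train_util.py behavior: warn and return None instead of raising,
    # so the trainer can continue without CDC-FM.
    if not FAISS_AVAILABLE:
        logger.warning("FAISS not installed; CDC-FM disabled. Install with: pip install faiss-cpu")
        return None
    return run_preprocessing()
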
rockerBOO
7ca799ca26 Add adaptive k_neighbors support for CDC-FM
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size
- Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16)
- Fixed mode (default): Skip buckets with < k_neighbors samples
- Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size
- Update CDCPreprocessor to support adaptive k per bucket
- Add metadata tracking for adaptive_k and min_bucket_size
- Add comprehensive pytest tests for adaptive k behavior

This allows CDC-FM to work effectively with multi-resolution bucketing where
bucket sizes may vary widely. Users can choose between the strict paper methodology
(fixed k) and a pragmatic approach (adaptive k).
2025-10-09 23:16:44 -04:00
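
The selection rule above, distilled (effective_k is a hypothetical helper paraphrasing the library/cdc_fm.py diff below; None means the bucket falls back to standard Gaussian noise):

def effective_k(bucket_size: int, k_neighbors: int, adaptive_k: bool, min_bucket_size: int = 16):
    if adaptive_k:
        if bucket_size < min_bucket_size:
            return None  # too small even for adaptive mode
        return min(k_neighbors, bucket_size - 1)  # cap k at bucket_size - 1
    # fixed mode: strict paper methodology
    return k_neighbors if bucket_size >= k_neighbors else None

# effective_k(20, 32, adaptive_k=True, min_bucket_size=8) -> 19
# effective_k(20, 32, adaptive_k=False) -> None (bucket skipped)
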
rockerBOO
f450443fe4 Add CDC-FM parameters to model metadata
- Add ss_use_cdc_fm, ss_cdc_k_neighbors, ss_cdc_k_bandwidth, ss_cdc_d_cdc, ss_cdc_gamma
- Ensures CDC-FM training parameters are tracked in model metadata
- Enables reproducibility and model provenance tracking
2025-10-09 22:51:47 -04:00
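
Since safetensors stores header metadata as strings, these keys can be read back for a provenance check; a hedged sketch (the file name is hypothetical):

from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt") as f:
    meta = f.metadata() or {}
for key in ("ss_use_cdc_fm", "ss_cdc_k_neighbors", "ss_cdc_k_bandwidth", "ss_cdc_d_cdc", "ss_cdc_gamma"):
    print(key, "=", meta.get(key))
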
6 changed files with 352 additions and 117 deletions

View File

@@ -1,91 +0,0 @@
"""
Benchmark script to measure performance improvement from caching shapes in memory.
Simulates the get_shape() calls that happen during training.
"""
import time
import tempfile
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset

def create_test_cache(num_samples=500, shape=(16, 64, 64)):
"""Create a test CDC cache file"""
preprocessor = CDCPreprocessor(
k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
print(f"Creating test cache with {num_samples} samples...")
for i in range(num_samples):
latent = torch.randn(*shape, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
temp_file = Path(tempfile.mktemp(suffix=".safetensors"))
preprocessor.compute_all(save_path=temp_file)
return temp_file

def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
"""Benchmark repeated get_shape() calls"""
print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
print("=" * 60)
# Load dataset (this is when caching happens)
load_start = time.time()
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
load_time = time.time() - load_start
print(f"Dataset load time (with caching): {load_time:.4f}s")
# Benchmark shape access
num_samples = dataset.num_samples
total_accesses = 0
start = time.time()
for iteration in range(num_iterations):
# Simulate a training batch
for _ in range(batch_size):
idx = iteration % num_samples
shape = dataset.get_shape(idx)
total_accesses += 1
elapsed = time.time() - start
print(f"\nResults:")
print(f" Total shape accesses: {total_accesses}")
print(f" Total time: {elapsed:.4f}s")
print(f" Average per access: {elapsed / total_accesses * 1000:.4f}ms")
print(f" Throughput: {total_accesses / elapsed:.1f} accesses/sec")
return elapsed, total_accesses

def main():
print("CDC Shape Cache Benchmark")
print("=" * 60)
# Create test cache
cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))
try:
# Benchmark with typical training workload
# Simulates 1000 training steps with batch_size=8
benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)
print("\n" + "=" * 60)
print("Summary:")
print(" With in-memory caching, shape access should be:")
print(" - Sub-millisecond per access")
print(" - No disk I/O after initial load")
print(" - Constant time regardless of cache file size")
finally:
# Cleanup
if cache_path.exists():
cache_path.unlink()
print(f"\nCleaned up test file: {cache_path}")

if __name__ == "__main__":
main()

View File

@@ -461,6 +461,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
# CDC-FM metadata
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
@@ -586,6 +595,23 @@ def setup_parser() -> argparse.ArgumentParser:
help="Enable verbose CDC debug output showing bucket details"
" / CDCの詳細デバッグ出力を有効化バケット詳細表示",
)
parser.add_argument(
"--cdc_adaptive_k",
action="store_true",
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
)
parser.add_argument(
"--cdc_min_bucket_size",
type=int,
default=16,
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスイズを使用。"
"--cdc_adaptive_k有効時のみ関連デフォルト: 16",
)
return parser
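
A hypothetical invocation combining the new flags with the existing CDC-FM switches (the script name and the --use_cdc_fm/--cdc_k_neighbors flags are inferred from the metadata hunk above, not shown in this diff):

accelerate launch flux_train_network.py \
    --use_cdc_fm --cdc_k_neighbors 32 \
    --cdc_adaptive_k --cdc_min_bucket_size 12 \
    --cdc_debug
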

View File

@@ -1,13 +1,18 @@
import logging
import torch
import numpy as np
import faiss # type: ignore
from pathlib import Path
from tqdm import tqdm
from safetensors.torch import save_file
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass
try:
import faiss # type: ignore
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger = logging.getLogger(__name__)
@@ -425,8 +430,17 @@ class CDCPreprocessor:
gamma: float = 1.0,
device: str = 'cuda',
size_tolerance: float = 0.0,
debug: bool = False
debug: bool = False,
adaptive_k: bool = False,
min_bucket_size: int = 16
):
if not FAISS_AVAILABLE:
raise ImportError(
"FAISS is required for CDC-FM but not installed. "
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU). "
"CDC-FM will be disabled."
)
self.computer = CarreDuChampComputer(
k_neighbors=k_neighbors,
k_bandwidth=k_bandwidth,
@@ -436,6 +450,8 @@ class CDCPreprocessor:
)
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
self.debug = debug
self.adaptive_k = adaptive_k
self.min_bucket_size = min_bucket_size
def add_latent(
self,
@@ -473,15 +489,23 @@ class CDCPreprocessor:
# Count samples that will get CDC vs fallback
k_neighbors = self.computer.k
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
if self.adaptive_k:
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
else:
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
samples_fallback = len(self.batcher) - samples_with_cdc
if self.debug:
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
if self.adaptive_k:
print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
else:
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
mode = "adaptive" if self.adaptive_k else "fixed"
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
# Storage for results
all_results = {}
@@ -497,22 +521,46 @@ class CDCPreprocessor:
print(f"Bucket: {shape} ({num_samples} samples)")
print(f"{'='*60}")
# Check if bucket has enough samples for k-NN
if num_samples < k_neighbors:
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Determine effective k for this bucket
if self.adaptive_k:
# Adaptive mode: skip if below minimum, otherwise use best available k
if num_samples < min_threshold:
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
continue
continue
# Use adaptive k for this bucket
k_effective = min(k_neighbors, num_samples - 1)
else:
# Fixed mode: skip if below k_neighbors
if num_samples < k_neighbors:
if self.debug:
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
print(" → These samples will use standard Gaussian noise (no CDC)")
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
C, H, W = shape
d = C * H * W
for sample in samples:
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
all_results[sample.global_idx] = (eigvecs, eigvals)
continue
k_effective = k_neighbors
# Collect latents (no resizing needed - all same shape)
latents_list = []
@@ -524,10 +572,18 @@ class CDCPreprocessor:
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
# Compute CDC for this batch
# Compute CDC for this batch with effective k
if self.debug:
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
if self.adaptive_k and k_effective < k_neighbors:
print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
else:
print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
# Temporarily override k for this bucket
original_k = self.computer.k
self.computer.k = k_effective
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
self.computer.k = original_k
# No resizing needed - eigenvectors are already correct size
if self.debug:
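
One design note on the temporary k override just above: if compute_for_batch raised, self.computer.k would stay clamped for every later bucket. A try/finally (a sketch, not part of this diff) makes the restore unconditional:

original_k = self.computer.k
try:
    self.computer.k = k_effective
    batch_results = self.computer.compute_for_batch(latents_np, global_indices)
finally:
    self.computer.k = original_k  # restore even if compute_for_batch raises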

View File

@@ -2707,6 +2707,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
force_recache: bool = False,
accelerator: Optional["Accelerator"] = None,
debug: bool = False,
adaptive_k: bool = False,
min_bucket_size: int = 16,
) -> str:
"""
Cache CDC Γ_b matrices for all latents in the dataset
@@ -2746,12 +2748,19 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
logger.info("Starting CDC-FM preprocessing")
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
logger.info("=" * 60)
# Initialize CDC preprocessor
from library.cdc_fm import CDCPreprocessor
# Initialize CDC preprocessor
try:
from library.cdc_fm import CDCPreprocessor
except ImportError as e:
logger.warning(
"FAISS not installed. CDC-FM preprocessing skipped. "
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU)"
)
return None
preprocessor = CDCPreprocessor(
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
)
# Get caching strategy for loading latents
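
Because this hunk adds a None return path, the method's declared -> str (visible in the context above) is now too narrow; Optional[str] would match the new behavior. A sketch only, since the diff elides the method name (cache_cdc_gamma_b here is hypothetical):

from typing import Optional

def cache_cdc_gamma_b(adaptive_k: bool = False, min_bucket_size: int = 16) -> Optional[str]:
    # hypothetical signature: returns the cache path, or None when FAISS is unavailable
    ...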

View File

@@ -0,0 +1,230 @@
"""
Test adaptive k_neighbors functionality in CDC-FM.
Verifies that adaptive k properly adjusts based on bucket sizes.
"""
import pytest
import torch
import numpy as np
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset

class TestAdaptiveK:
"""Test adaptive k_neighbors behavior"""
@pytest.fixture
def temp_cache_path(self, tmp_path):
"""Create temporary cache path"""
return tmp_path / "adaptive_k_test.safetensors"
def test_fixed_k_skips_small_buckets(self, temp_cache_path):
"""
Test that fixed k mode skips buckets with < k_neighbors samples.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=False # Fixed mode
)
# Add 10 samples (< k=32, should be skipped)
shape = (4, 16, 16)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify zeros (Gaussian fallback)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should be all zeros (fallback)
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
"""
Test that adaptive k mode uses k=bucket_size-1 for small buckets.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=8
)
# Add 20 samples (< k=32, should use k=19)
shape = (4, 16, 16)
for i in range(20):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify non-zero (CDC computed)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should NOT be all zeros (CDC was computed)
assert not torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
"""
Test that adaptive k mode skips buckets below min_bucket_size.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=16
)
# Add 10 samples (< min_bucket_size=16, should be skipped)
shape = (4, 16, 16)
for i in range(10):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify zeros (skipped due to min_bucket_size)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should be all zeros (skipped)
assert torch.allclose(eigvecs, torch.zeros_like(eigvecs), atol=1e-6)
assert torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
"""
Test adaptive k with multiple buckets of different sizes.
"""
preprocessor = CDCPreprocessor(
k_neighbors=32,
k_bandwidth=8,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=8
)
# Bucket 1: 10 samples (adaptive k=9)
for i in range(10):
latent = torch.randn(4, 16, 16, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=(4, 16, 16),
metadata={'image_key': f'small_{i}'}
)
# Bucket 2: 40 samples (full k=32)
for i in range(40):
latent = torch.randn(4, 32, 32, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=100+i,
shape=(4, 32, 32),
metadata={'image_key': f'large_{i}'}
)
# Bucket 3: 5 samples (< min=8, skipped)
for i in range(5):
latent = torch.randn(4, 8, 8, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=200+i,
shape=(4, 8, 8),
metadata={'image_key': f'tiny_{i}'}
)
preprocessor.compute_all(temp_cache_path)
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
# Bucket 1: Should have CDC (non-zero)
eigvecs_small, eigvals_small = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
assert not torch.allclose(eigvecs_small, torch.zeros_like(eigvecs_small), atol=1e-6)
# Bucket 2: Should have CDC (non-zero)
eigvecs_large, eigvals_large = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
assert not torch.allclose(eigvecs_large, torch.zeros_like(eigvecs_large), atol=1e-6)
# Bucket 3: Should be skipped (zeros)
eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
assert torch.allclose(eigvecs_tiny, torch.zeros_like(eigvecs_tiny), atol=1e-6)
assert torch.allclose(eigvals_tiny, torch.zeros_like(eigvals_tiny), atol=1e-6)
def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
"""
Test that adaptive k uses full k_neighbors when bucket is large enough.
"""
preprocessor = CDCPreprocessor(
k_neighbors=16,
k_bandwidth=4,
d_cdc=4,
gamma=1.0,
device='cpu',
debug=False,
adaptive_k=True,
min_bucket_size=8
)
# Add 50 samples (> k=16, should use full k=16)
shape = (4, 16, 16)
for i in range(50):
latent = torch.randn(*shape, dtype=torch.float32).numpy()
preprocessor.add_latent(
latent=latent,
global_idx=i,
shape=shape,
metadata={'image_key': f'test_{i}'}
)
preprocessor.compute_all(temp_cache_path)
# Load and verify CDC was computed
dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')
# Should have non-zero eigenvalues
assert not torch.allclose(eigvals, torch.zeros_like(eigvals), atol=1e-6)
# Eigenvalues should be positive
assert (eigvals >= 0).all()

if __name__ == "__main__":
pytest.main([__file__, "-v"])
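
These tests run on CPU (device='cpu' throughout) but require FAISS, since CDCPreprocessor now raises ImportError without it. A typical invocation (assuming the test file is collected by pytest's defaults):

pip install faiss-cpu pytest
pytest -v -k "TestAdaptiveK"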

View File

@@ -636,7 +636,12 @@ class NetworkTrainer:
force_recache=args.force_recache_cdc,
accelerator=accelerator,
debug=getattr(args, 'cdc_debug', False),
adaptive_k=getattr(args, 'cdc_adaptive_k', False),
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
)
if self.cdc_cache_path is None:
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
else:
self.cdc_cache_path = None
@@ -652,7 +657,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
if unet is none:
if unet is None:
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
@@ -661,10 +666,10 @@ class NetworkTrainer:
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not none:
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is none or len(args.base_weights_multiplier) <= i:
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]