mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
3 Commits
20c6ae5a9a
...
8458a5696e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8458a5696e | ||
|
|
7ca799ca26 | ||
|
|
f450443fe4 |
@@ -1,91 +0,0 @@
|
||||
"""
|
||||
Benchmark script to measure performance improvement from caching shapes in memory.
|
||||
|
||||
Simulates the get_shape() calls that happen during training.
|
||||
"""
|
||||
|
||||
import time
|
||||
import tempfile
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
def create_test_cache(num_samples=500, shape=(16, 64, 64)):
    """Create a temporary CDC cache file populated with random latents.

    Args:
        num_samples: Number of random latents registered with the preprocessor.
        shape: Shape used for every generated latent.

    Returns:
        Path of the ``.safetensors`` file passed to ``compute_all``. The caller
        is responsible for deleting it.
    """
    import os  # local import: only needed to close the mkstemp descriptor

    preprocessor = CDCPreprocessor(
        k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
    )

    print(f"Creating test cache with {num_samples} samples...")
    for i in range(num_samples):
        latent = torch.randn(*shape, dtype=torch.float32)
        preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)

    # mkstemp creates the file atomically. tempfile.mktemp is deprecated because
    # another process can claim the returned name before it is written.
    fd, temp_name = tempfile.mkstemp(suffix=".safetensors")
    os.close(fd)
    temp_file = Path(temp_name)

    try:
        preprocessor.compute_all(save_path=temp_file)
    except BaseException:
        # Do not leave a stray temp file behind when cache creation fails.
        temp_file.unlink(missing_ok=True)
        raise

    return temp_file
|
||||
|
||||
|
||||
def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
    """Benchmark repeated ``get_shape()`` calls on a loaded CDC cache.

    Args:
        cache_path: Path of the cache file handed to ``GammaBDataset``.
        num_iterations: Number of simulated training steps.
        batch_size: Number of ``get_shape()`` calls per simulated step.

    Returns:
        Tuple ``(elapsed, total_accesses)``: seconds spent in the access loop
        and the number of ``get_shape()`` calls performed.
    """
    print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
    print("=" * 60)

    # Load dataset (this is when caching happens).
    # perf_counter is monotonic and high resolution; time.time() can report a
    # zero interval for short runs on platforms with a coarse wall clock.
    load_start = time.perf_counter()
    dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
    load_time = time.perf_counter() - load_start
    print(f"Dataset load time (with caching): {load_time:.4f}s")

    # Benchmark shape access
    num_samples = dataset.num_samples
    total_accesses = 0

    start = time.perf_counter()
    if num_samples > 0:  # an empty dataset has nothing to index
        for iteration in range(num_iterations):
            # Simulate a training batch: each slot reads a different sample
            # instead of hitting the same index batch_size times.
            for offset in range(batch_size):
                idx = (iteration * batch_size + offset) % num_samples
                dataset.get_shape(idx)
                total_accesses += 1
    elapsed = time.perf_counter() - start

    # Guard the derived metrics so a zero count or zero interval cannot raise.
    avg_ms = elapsed / total_accesses * 1000 if total_accesses else 0.0
    throughput = total_accesses / elapsed if elapsed > 0 else float("inf")

    print("\nResults:")
    print(f" Total shape accesses: {total_accesses}")
    print(f" Total time: {elapsed:.4f}s")
    print(f" Average per access: {avg_ms:.4f}ms")
    print(f" Throughput: {throughput:.1f} accesses/sec")

    return elapsed, total_accesses
|
||||
|
||||
|
||||
def main():
    """Build a throwaway cache, benchmark shape access on it, then delete it."""
    rule = "=" * 60

    print("CDC Shape Cache Benchmark")
    print(rule)

    # The cache is created up front; the finally clause below removes it.
    cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))

    try:
        # Typical training workload: 1000 steps with 8 accesses per step.
        benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)

        summary_lines = (
            "\n" + rule,
            "Summary:",
            " With in-memory caching, shape access should be:",
            " - Sub-millisecond per access",
            " - No disk I/O after initial load",
            " - Constant time regardless of cache file size",
        )
        for line in summary_lines:
            print(line)
    finally:
        # Remove the temporary cache even when the benchmark raised.
        if cache_path.exists():
            cache_path.unlink()
            print(f"\nCleaned up test file: {cache_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -461,6 +461,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
# CDC-FM metadata
|
||||
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
|
||||
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
|
||||
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
|
||||
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
|
||||
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
|
||||
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
|
||||
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
@@ -586,6 +595,23 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Enable verbose CDC debug output showing bucket details"
|
||||
" / CDCの詳細デバッグ出力を有効化(バケット詳細表示)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_adaptive_k",
|
||||
action="store_true",
|
||||
help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
|
||||
"k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
|
||||
" / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
|
||||
"CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cdc_min_bucket_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
|
||||
"Only relevant when --cdc_adaptive_k is enabled (default: 16)"
|
||||
" / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスノイズを使用。"
|
||||
"--cdc_adaptive_k有効時のみ関連(デフォルト: 16)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import faiss # type: ignore
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from safetensors.torch import save_file
|
||||
from typing import List, Dict, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
try:
|
||||
import faiss # type: ignore
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -425,8 +430,17 @@ class CDCPreprocessor:
|
||||
gamma: float = 1.0,
|
||||
device: str = 'cuda',
|
||||
size_tolerance: float = 0.0,
|
||||
debug: bool = False
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16
|
||||
):
|
||||
if not FAISS_AVAILABLE:
|
||||
raise ImportError(
|
||||
"FAISS is required for CDC-FM but not installed. "
|
||||
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU). "
|
||||
"CDC-FM will be disabled."
|
||||
)
|
||||
|
||||
self.computer = CarreDuChampComputer(
|
||||
k_neighbors=k_neighbors,
|
||||
k_bandwidth=k_bandwidth,
|
||||
@@ -436,6 +450,8 @@ class CDCPreprocessor:
|
||||
)
|
||||
self.batcher = LatentBatcher(size_tolerance=size_tolerance)
|
||||
self.debug = debug
|
||||
self.adaptive_k = adaptive_k
|
||||
self.min_bucket_size = min_bucket_size
|
||||
|
||||
def add_latent(
|
||||
self,
|
||||
@@ -473,15 +489,23 @@ class CDCPreprocessor:
|
||||
|
||||
# Count samples that will get CDC vs fallback
|
||||
k_neighbors = self.computer.k
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors
|
||||
|
||||
if self.adaptive_k:
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold)
|
||||
else:
|
||||
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors)
|
||||
samples_fallback = len(self.batcher) - samples_with_cdc
|
||||
|
||||
if self.debug:
|
||||
print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets")
|
||||
print(f" Samples with CDC (≥{k_neighbors} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
if self.adaptive_k:
|
||||
print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}")
|
||||
print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)")
|
||||
print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)")
|
||||
else:
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets: {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
mode = "adaptive" if self.adaptive_k else "fixed"
|
||||
logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback")
|
||||
|
||||
# Storage for results
|
||||
all_results = {}
|
||||
@@ -497,22 +521,46 @@ class CDCPreprocessor:
|
||||
print(f"Bucket: {shape} ({num_samples} samples)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if bucket has enough samples for k-NN
|
||||
if num_samples < k_neighbors:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
# Determine effective k for this bucket
|
||||
if self.adaptive_k:
|
||||
# Adaptive mode: skip if below minimum, otherwise use best available k
|
||||
if num_samples < min_threshold:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
continue
|
||||
|
||||
# Use adaptive k for this bucket
|
||||
k_effective = min(k_neighbors, num_samples - 1)
|
||||
else:
|
||||
# Fixed mode: skip if below k_neighbors
|
||||
if num_samples < k_neighbors:
|
||||
if self.debug:
|
||||
print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}")
|
||||
print(" → These samples will use standard Gaussian noise (no CDC)")
|
||||
|
||||
# Store zero eigenvectors/eigenvalues (Gaussian fallback)
|
||||
C, H, W = shape
|
||||
d = C * H * W
|
||||
|
||||
for sample in samples:
|
||||
eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16)
|
||||
eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16)
|
||||
all_results[sample.global_idx] = (eigvecs, eigvals)
|
||||
|
||||
continue
|
||||
|
||||
k_effective = k_neighbors
|
||||
|
||||
# Collect latents (no resizing needed - all same shape)
|
||||
latents_list = []
|
||||
@@ -524,10 +572,18 @@ class CDCPreprocessor:
|
||||
|
||||
latents_np = np.stack(latents_list, axis=0) # (N, C*H*W)
|
||||
|
||||
# Compute CDC for this batch
|
||||
# Compute CDC for this batch with effective k
|
||||
if self.debug:
|
||||
print(f" Computing CDC with k={k_neighbors} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
if self.adaptive_k and k_effective < k_neighbors:
|
||||
print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}")
|
||||
else:
|
||||
print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}")
|
||||
|
||||
# Temporarily override k for this bucket
|
||||
original_k = self.computer.k
|
||||
self.computer.k = k_effective
|
||||
batch_results = self.computer.compute_for_batch(latents_np, global_indices)
|
||||
self.computer.k = original_k
|
||||
|
||||
# No resizing needed - eigenvectors are already correct size
|
||||
if self.debug:
|
||||
|
||||
@@ -2707,6 +2707,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
force_recache: bool = False,
|
||||
accelerator: Optional["Accelerator"] = None,
|
||||
debug: bool = False,
|
||||
adaptive_k: bool = False,
|
||||
min_bucket_size: int = 16,
|
||||
) -> str:
|
||||
"""
|
||||
Cache CDC Γ_b matrices for all latents in the dataset
|
||||
@@ -2746,12 +2748,19 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
logger.info("Starting CDC-FM preprocessing")
|
||||
logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize CDC preprocessor
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
# Initialize CDC preprocessor
|
||||
try:
|
||||
from library.cdc_fm import CDCPreprocessor
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"FAISS not installed. CDC-FM preprocessing skipped. "
|
||||
"Install with: pip install faiss-cpu (CPU) or faiss-gpu (GPU)"
|
||||
)
|
||||
return None
|
||||
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug
|
||||
k_neighbors=k_neighbors, k_bandwidth=k_bandwidth, d_cdc=d_cdc, gamma=gamma, device="cuda" if torch.cuda.is_available() else "cpu", debug=debug, adaptive_k=adaptive_k, min_bucket_size=min_bucket_size
|
||||
)
|
||||
|
||||
# Get caching strategy for loading latents
|
||||
|
||||
230
tests/library/test_cdc_adaptive_k.py
Normal file
230
tests/library/test_cdc_adaptive_k.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Test adaptive k_neighbors functionality in CDC-FM.
|
||||
|
||||
Verifies that adaptive k properly adjusts based on bucket sizes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
class TestAdaptiveK:
    """Test adaptive k_neighbors behavior of ``CDCPreprocessor``.

    Each test fills one or more same-shape buckets with random latents, runs
    ``compute_all`` and reads the result back through ``GammaBDataset``. All-zero
    eigenvectors/eigenvalues mean the bucket fell back to Gaussian noise.
    """

    @pytest.fixture(autouse=True)
    def _faiss_and_seed(self):
        """Skip without FAISS and make the random latents reproducible."""
        # CDCPreprocessor raises ImportError when FAISS is missing; report that
        # as a skip rather than an error.
        pytest.importorskip("faiss")
        torch.manual_seed(0)

    @pytest.fixture
    def temp_cache_path(self, tmp_path):
        """Create temporary cache path"""
        return tmp_path / "adaptive_k_test.safetensors"

    @staticmethod
    def _add_random_latents(preprocessor, count, shape, key_prefix, start_idx=0):
        """Register ``count`` random latents of ``shape`` under ``<key_prefix>_<i>``."""
        for i in range(count):
            latent = torch.randn(*shape, dtype=torch.float32).numpy()
            preprocessor.add_latent(
                latent=latent,
                global_idx=start_idx + i,
                shape=shape,
                metadata={'image_key': f'{key_prefix}_{i}'},
            )

    @staticmethod
    def _is_all_zero(tensor):
        """Return True when ``tensor`` is zero within the tests' tolerance."""
        return torch.allclose(tensor, torch.zeros_like(tensor), atol=1e-6)

    def test_fixed_k_skips_small_buckets(self, temp_cache_path):
        """
        Test that fixed k mode skips buckets with < k_neighbors samples.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=False,  # Fixed mode
        )

        # Add 10 samples (< k=32, should be skipped)
        self._add_random_latents(preprocessor, 10, (4, 16, 16), 'test')

        preprocessor.compute_all(temp_cache_path)

        # Load and verify zeros (Gaussian fallback)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should be all zeros (fallback)
        assert self._is_all_zero(eigvecs)
        assert self._is_all_zero(eigvals)

    def test_adaptive_k_uses_available_neighbors(self, temp_cache_path):
        """
        Test that adaptive k mode uses k=bucket_size-1 for small buckets.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8,
        )

        # Add 20 samples (< k=32, should use k=19)
        self._add_random_latents(preprocessor, 20, (4, 16, 16), 'test')

        preprocessor.compute_all(temp_cache_path)

        # Load and verify non-zero (CDC computed)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should NOT be all zeros (CDC was computed)
        assert not self._is_all_zero(eigvecs)
        assert not self._is_all_zero(eigvals)

    def test_adaptive_k_respects_min_bucket_size(self, temp_cache_path):
        """
        Test that adaptive k mode skips buckets below min_bucket_size.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=16,
        )

        # Add 10 samples (< min_bucket_size=16, should be skipped)
        self._add_random_latents(preprocessor, 10, (4, 16, 16), 'test')

        preprocessor.compute_all(temp_cache_path)

        # Load and verify zeros (skipped due to min_bucket_size)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        eigvecs, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should be all zeros (skipped)
        assert self._is_all_zero(eigvecs)
        assert self._is_all_zero(eigvals)

    def test_adaptive_k_mixed_bucket_sizes(self, temp_cache_path):
        """
        Test adaptive k with multiple buckets of different sizes.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=32,
            k_bandwidth=8,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8,
        )

        # Bucket 1: 10 samples (adaptive k=9)
        self._add_random_latents(preprocessor, 10, (4, 16, 16), 'small')

        # Bucket 2: 40 samples (full k=32)
        self._add_random_latents(preprocessor, 40, (4, 32, 32), 'large', start_idx=100)

        # Bucket 3: 5 samples (< min=8, skipped)
        self._add_random_latents(preprocessor, 5, (4, 8, 8), 'tiny', start_idx=200)

        preprocessor.compute_all(temp_cache_path)
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')

        # Bucket 1: Should have CDC (non-zero)
        eigvecs_small, _ = dataset.get_gamma_b_sqrt(['small_0'], device='cpu')
        assert not self._is_all_zero(eigvecs_small)

        # Bucket 2: Should have CDC (non-zero)
        eigvecs_large, _ = dataset.get_gamma_b_sqrt(['large_0'], device='cpu')
        assert not self._is_all_zero(eigvecs_large)

        # Bucket 3: Should be skipped (zeros)
        eigvecs_tiny, eigvals_tiny = dataset.get_gamma_b_sqrt(['tiny_0'], device='cpu')
        assert self._is_all_zero(eigvecs_tiny)
        assert self._is_all_zero(eigvals_tiny)

    def test_adaptive_k_uses_full_k_when_available(self, temp_cache_path):
        """
        Test that adaptive k uses full k_neighbors when bucket is large enough.
        """
        preprocessor = CDCPreprocessor(
            k_neighbors=16,
            k_bandwidth=4,
            d_cdc=4,
            gamma=1.0,
            device='cpu',
            debug=False,
            adaptive_k=True,
            min_bucket_size=8,
        )

        # Add 50 samples (> k=16, should use full k=16)
        self._add_random_latents(preprocessor, 50, (4, 16, 16), 'test')

        preprocessor.compute_all(temp_cache_path)

        # Load and verify CDC was computed
        dataset = GammaBDataset(gamma_b_path=temp_cache_path, device='cpu')
        _, eigvals = dataset.get_gamma_b_sqrt(['test_0'], device='cpu')

        # Should have non-zero eigenvalues
        assert not self._is_all_zero(eigvals)
        # Eigenvalues should be non-negative
        assert (eigvals >= 0).all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -636,7 +636,12 @@ class NetworkTrainer:
|
||||
force_recache=args.force_recache_cdc,
|
||||
accelerator=accelerator,
|
||||
debug=getattr(args, 'cdc_debug', False),
|
||||
adaptive_k=getattr(args, 'cdc_adaptive_k', False),
|
||||
min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16),
|
||||
)
|
||||
|
||||
if self.cdc_cache_path is None:
|
||||
logger.warning("CDC-FM preprocessing failed (likely missing FAISS). Training will continue without CDC-FM.")
|
||||
else:
|
||||
self.cdc_cache_path = None
|
||||
|
||||
@@ -652,7 +657,7 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
|
||||
|
||||
if unet is none:
|
||||
if unet is None:
|
||||
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
|
||||
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
|
||||
|
||||
@@ -661,10 +666,10 @@ class NetworkTrainer:
|
||||
accelerator.print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not none:
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is none or len(args.base_weights_multiplier) <= i:
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
Reference in New Issue
Block a user