Files
Kohya-ss-sd-scripts/benchmark_cdc_shape_cache.py
rockerBOO e03200bdba Optimize: Cache CDC shapes in memory to eliminate I/O bottleneck
- Cache all shapes during GammaBDataset initialization
- Eliminates file I/O on every training step (9.5M accesses/sec)
- Reduces get_shape() from file operation to dict lookup
- Memory overhead: ~126 bytes/sample (~12.6 MB per 100k images)
2025-10-09 18:28:50 -04:00

92 lines
2.8 KiB
Python

"""
Benchmark script to measure performance improvement from caching shapes in memory.
Simulates the get_shape() calls that happen during training.
"""
import time
import tempfile
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
def create_test_cache(num_samples=500, shape=(16, 64, 64)):
    """Create a temporary CDC cache file populated with random latents.

    Args:
        num_samples: Number of random latents fed to the preprocessor.
        shape: Shape of every generated latent; also forwarded to
            ``add_latent`` as its ``shape`` argument.

    Returns:
        Path of the ``.safetensors`` file written by ``compute_all``. The
        caller is responsible for deleting it.
    """
    preprocessor = CDCPreprocessor(
        k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
    )

    print(f"Creating test cache with {num_samples} samples...")
    for i in range(num_samples):
        latent = torch.randn(*shape, dtype=torch.float32)
        preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)

    # tempfile.mktemp() is deprecated: it only returns a name, so another
    # process can claim the path before we write to it. Reserve the file
    # atomically instead and close the handle right away so that the writer
    # can reopen the path on every platform.
    # NOTE(review): assumes compute_all() overwrites an existing (empty)
    # file at save_path -- confirm against CDCPreprocessor.
    with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as handle:
        temp_file = Path(handle.name)

    try:
        preprocessor.compute_all(save_path=temp_file)
    except BaseException:
        # Do not leave the reserved placeholder behind when saving fails
        # (including on Ctrl-C); the error is re-raised unchanged.
        if temp_file.exists():
            temp_file.unlink()
        raise

    return temp_file
def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
    """Benchmark repeated ``get_shape()`` calls on a loaded ``GammaBDataset``.

    Loads the dataset from ``cache_path``, then performs
    ``num_iterations * batch_size`` shape lookups and prints timing
    statistics.

    Args:
        cache_path: Path to the CDC cache file to load.
        num_iterations: Number of simulated training steps.
        batch_size: Number of ``get_shape()`` calls per simulated step.

    Returns:
        Tuple ``(elapsed, total_accesses)``: seconds spent in the access loop
        and the number of ``get_shape()`` calls made.

    Raises:
        ValueError: If the loaded dataset reports no samples.
    """
    print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
    print("=" * 60)

    # Load dataset (per the original note, this is when caching happens).
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which can jump with system clock adjustments.
    load_start = time.perf_counter()
    dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
    load_time = time.perf_counter() - load_start
    print(f"Dataset load time (with caching): {load_time:.4f}s")

    num_samples = dataset.num_samples
    if num_samples <= 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError
        # from the modulo below.
        raise ValueError(f"CDC cache at {cache_path} contains no samples")

    # Benchmark shape access
    total_accesses = 0
    start = time.perf_counter()
    for iteration in range(num_iterations):
        # Simulate a training batch; the index is constant within a batch,
        # so compute it once per iteration.
        idx = iteration % num_samples
        for _ in range(batch_size):
            dataset.get_shape(idx)
            total_accesses += 1
    elapsed = time.perf_counter() - start

    # Guard the derived statistics so that a zero-length run (no iterations,
    # empty batches) or an unmeasurably short one still reports cleanly.
    avg_ms = elapsed / total_accesses * 1000 if total_accesses else 0.0
    throughput = total_accesses / elapsed if elapsed > 0 else float("inf")

    print("\nResults:")
    print(f" Total shape accesses: {total_accesses}")
    print(f" Total time: {elapsed:.4f}s")
    print(f" Average per access: {avg_ms:.4f}ms")
    print(f" Throughput: {throughput:.1f} accesses/sec")

    return elapsed, total_accesses
def main():
    """Run the shape-cache benchmark end to end.

    Builds a temporary test cache, benchmarks shape access against it,
    prints a short summary of the expected characteristics, and removes the
    temporary file afterwards even if the benchmark raises.
    """
    rule = "=" * 60

    print("CDC Shape Cache Benchmark")
    print(rule)

    # Create test cache
    tmp_cache = create_test_cache(num_samples=500, shape=(16, 64, 64))

    try:
        # Typical training workload: 1000 steps with batch_size=8
        benchmark_shape_access(tmp_cache, num_iterations=1000, batch_size=8)

        print("\n" + rule)
        summary = (
            "Summary:",
            " With in-memory caching, shape access should be:",
            " - Sub-millisecond per access",
            " - No disk I/O after initial load",
            " - Constant time regardless of cache file size",
        )
        for text in summary:
            print(text)
    finally:
        # Cleanup: only report removal when there was a file to remove
        if tmp_cache.exists():
            tmp_cache.unlink()
            print(f"\nCleaned up test file: {tmp_cache}")
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()