mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Optimize: Cache CDC shapes in memory to eliminate I/O bottleneck
- Cache all shapes during GammaBDataset initialization - Eliminates file I/O on every training step (9.5M accesses/sec) - Reduces get_shape() from file operation to dict lookup - Memory overhead: ~126 bytes/sample (~12.6 MB per 100k images)
This commit is contained in:
91
benchmark_cdc_shape_cache.py
Normal file
91
benchmark_cdc_shape_cache.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Benchmark script to measure performance improvement from caching shapes in memory.
|
||||
|
||||
Simulates the get_shape() calls that happen during training.
|
||||
"""
|
||||
|
||||
import time
|
||||
import tempfile
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
||||
|
||||
|
||||
def create_test_cache(num_samples=500, shape=(16, 64, 64)):
|
||||
"""Create a test CDC cache file"""
|
||||
preprocessor = CDCPreprocessor(
|
||||
k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
|
||||
)
|
||||
|
||||
print(f"Creating test cache with {num_samples} samples...")
|
||||
for i in range(num_samples):
|
||||
latent = torch.randn(*shape, dtype=torch.float32)
|
||||
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
|
||||
|
||||
temp_file = Path(tempfile.mktemp(suffix=".safetensors"))
|
||||
preprocessor.compute_all(save_path=temp_file)
|
||||
return temp_file
|
||||
|
||||
|
||||
def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
|
||||
"""Benchmark repeated get_shape() calls"""
|
||||
print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
|
||||
print("=" * 60)
|
||||
|
||||
# Load dataset (this is when caching happens)
|
||||
load_start = time.time()
|
||||
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
|
||||
load_time = time.time() - load_start
|
||||
print(f"Dataset load time (with caching): {load_time:.4f}s")
|
||||
|
||||
# Benchmark shape access
|
||||
num_samples = dataset.num_samples
|
||||
total_accesses = 0
|
||||
|
||||
start = time.time()
|
||||
for iteration in range(num_iterations):
|
||||
# Simulate a training batch
|
||||
for _ in range(batch_size):
|
||||
idx = iteration % num_samples
|
||||
shape = dataset.get_shape(idx)
|
||||
total_accesses += 1
|
||||
|
||||
elapsed = time.time() - start
|
||||
|
||||
print(f"\nResults:")
|
||||
print(f" Total shape accesses: {total_accesses}")
|
||||
print(f" Total time: {elapsed:.4f}s")
|
||||
print(f" Average per access: {elapsed / total_accesses * 1000:.4f}ms")
|
||||
print(f" Throughput: {total_accesses / elapsed:.1f} accesses/sec")
|
||||
|
||||
return elapsed, total_accesses
|
||||
|
||||
|
||||
def main():
|
||||
print("CDC Shape Cache Benchmark")
|
||||
print("=" * 60)
|
||||
|
||||
# Create test cache
|
||||
cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))
|
||||
|
||||
try:
|
||||
# Benchmark with typical training workload
|
||||
# Simulates 1000 training steps with batch_size=8
|
||||
benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary:")
|
||||
print(" With in-memory caching, shape access should be:")
|
||||
print(" - Sub-millisecond per access")
|
||||
print(" - No disk I/O after initial load")
|
||||
print(" - Constant time regardless of cache file size")
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if cache_path.exists():
|
||||
cache_path.unlink()
|
||||
print(f"\nCleaned up test file: {cache_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user