# Mirror of https://github.com/kohya-ss/sd-scripts.git
# Synced 2026-04-15 16:39:42 +00:00
#
# Change notes:
#   - Cache all shapes during GammaBDataset initialization
#   - Eliminates file I/O on every training step (9.5M accesses/sec)
#   - Reduces get_shape() from file operation to dict lookup
#   - Memory overhead: ~126 bytes/sample (~12.6 MB per 100k images)
#
# File: 92 lines, 2.8 KiB, Python
"""
Benchmark script to measure performance improvement from caching shapes in memory.

Simulates the get_shape() calls that happen during training.
"""
|
|
|
|
import time
|
|
import tempfile
|
|
import torch
|
|
from pathlib import Path
|
|
from library.cdc_fm import CDCPreprocessor, GammaBDataset
|
|
|
|
|
|
def create_test_cache(num_samples=500, shape=(16, 64, 64)):
    """Create a test CDC cache file populated with random latents.

    Args:
        num_samples: Number of random latents handed to the preprocessor.
        shape: Shape of every generated latent; the same tuple is also passed
            as the ``shape`` argument of ``add_latent``.

    Returns:
        ``pathlib.Path`` of the temporary ``.safetensors`` file that was passed
        to ``compute_all`` as ``save_path``. The caller owns the file and is
        responsible for deleting it.
    """
    preprocessor = CDCPreprocessor(
        k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
    )

    print(f"Creating test cache with {num_samples} samples...")
    for i in range(num_samples):
        latent = torch.randn(*shape, dtype=torch.float32)
        preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)

    # tempfile.mktemp() is deprecated: it only returns a name, so another
    # process can create that path before we do. Reserve the file atomically
    # instead and close the handle right away so it can be reopened for
    # writing on every platform.
    # NOTE(review): assumes compute_all() overwrites an existing (empty) file
    # at save_path rather than treating it as an already-built cache - confirm.
    with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as handle:
        temp_file = Path(handle.name)

    try:
        preprocessor.compute_all(save_path=temp_file)
    except BaseException:
        # Do not leave the reserved file behind when cache creation fails.
        temp_file.unlink(missing_ok=True)
        raise
    return temp_file
|
|
|
|
|
|
def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
    """Benchmark repeated get_shape() calls.

    Loads a ``GammaBDataset`` from ``cache_path`` (timed separately), then
    performs ``num_iterations * batch_size`` calls to ``dataset.get_shape``.
    Every call within one iteration uses the same index,
    ``iteration % dataset.num_samples``.

    Args:
        cache_path: Path passed to ``GammaBDataset`` as ``gamma_b_path``.
        num_iterations: Number of simulated training steps.
        batch_size: Number of ``get_shape`` calls per simulated step.

    Returns:
        Tuple ``(elapsed, total_accesses)``: seconds spent in the access loop
        and the number of ``get_shape`` calls made.
    """
    print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
    print("=" * 60)

    # Load dataset (this is when caching happens).
    # perf_counter() is monotonic and high-resolution; time.time() can report
    # a zero interval for a loop this short.
    load_start = time.perf_counter()
    dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
    load_time = time.perf_counter() - load_start
    print(f"Dataset load time (with caching): {load_time:.4f}s")

    # Benchmark shape access
    num_samples = dataset.num_samples
    total_accesses = 0

    start = time.perf_counter()
    for iteration in range(num_iterations):
        # Simulate a training batch
        for _ in range(batch_size):
            idx = iteration % num_samples
            dataset.get_shape(idx)
            total_accesses += 1
    elapsed = time.perf_counter() - start

    # Guard both ratios so an empty run (no accesses) or an interval below
    # the timer resolution reports a value instead of raising
    # ZeroDivisionError.
    avg_ms = elapsed / total_accesses * 1000 if total_accesses else 0.0
    throughput = total_accesses / elapsed if elapsed > 0 else float("inf")

    print("\nResults:")
    print(f" Total shape accesses: {total_accesses}")
    print(f" Total time: {elapsed:.4f}s")
    print(f" Average per access: {avg_ms:.4f}ms")
    print(f" Throughput: {throughput:.1f} accesses/sec")

    return elapsed, total_accesses
|
|
|
|
|
|
def main():
    """Run the shape-cache benchmark end to end.

    Builds a 500-sample test cache, runs ``benchmark_shape_access`` against it
    for 1000 iterations with a batch size of 8, prints a summary of what
    in-memory caching is expected to deliver, and finally deletes the test
    cache file if it still exists.
    """
    divider = "=" * 60
    summary_lines = (
        "Summary:",
        " With in-memory caching, shape access should be:",
        " - Sub-millisecond per access",
        " - No disk I/O after initial load",
        " - Constant time regardless of cache file size",
    )

    print("CDC Shape Cache Benchmark")
    print(divider)

    # The cache is built before the try block, so the cleanup below only runs
    # once a cache path has actually been returned.
    cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))

    try:
        # Typical training workload: 1000 steps, 8 samples per step.
        benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)

        print(f"\n{divider}")
        for line in summary_lines:
            print(line)
    finally:
        # Remove the test cache whether or not the benchmark succeeded.
        if cache_path.exists():
            cache_path.unlink()
            print(f"\nCleaned up test file: {cache_path}")
|
|
|
|
|
|
# Run the benchmark only when this file is executed as a script, so importing
# the module has no side effects.
if __name__ == "__main__":
    main()
|