Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16 00:49:40 +00:00)
Optimize: Cache CDC shapes in memory to eliminate I/O bottleneck
- Cache all shapes during GammaBDataset initialization
- Eliminates file I/O on every training step; cached lookups sustain ~9.5M accesses/sec
- Reduces get_shape() from a file operation to a dict lookup
- Memory overhead: ~126 bytes/sample (~12.6 MB per 100k images)
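A back-of-envelope check of that memory figure (a hypothetical sketch, not part of the commit; exact numbers vary by CPython version, and the (16, 64, 64) shape is just an example) should land in the same ballpark as the ~126 bytes/sample quoted above:

import sys

n = 100_000
shapes_cache = {i: (16, 64, 64) for i in range(n)}  # hypothetical 3-dim latent shapes

# sys.getsizeof() counts the dict's hash table but not the keys and values
# it references, so those are added separately.
total = sys.getsizeof(shapes_cache)
total += sum(sys.getsizeof(k) for k in shapes_cache)           # int keys
total += sum(sys.getsizeof(v) for v in shapes_cache.values())  # 3-tuples
print(f"~{total / n:.0f} bytes/sample, ~{total / 1e6:.1f} MB for {n} samples")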
benchmark_cdc_shape_cache.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""
Benchmark script to measure performance improvement from caching shapes in memory.

Simulates the get_shape() calls that happen during training.
"""

import time
import tempfile
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset


def create_test_cache(num_samples=500, shape=(16, 64, 64)):
    """Create a test CDC cache file"""
    preprocessor = CDCPreprocessor(
        k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
    )

    print(f"Creating test cache with {num_samples} samples...")
    for i in range(num_samples):
        latent = torch.randn(*shape, dtype=torch.float32)
        preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)

    temp_file = Path(tempfile.mktemp(suffix=".safetensors"))
    preprocessor.compute_all(save_path=temp_file)
    return temp_file


def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
    """Benchmark repeated get_shape() calls"""
    print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
    print("=" * 60)

    # Load dataset (this is when caching happens)
    load_start = time.time()
    dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
    load_time = time.time() - load_start
    print(f"Dataset load time (with caching): {load_time:.4f}s")

    # Benchmark shape access
    num_samples = dataset.num_samples
    total_accesses = 0

    start = time.time()
    for iteration in range(num_iterations):
        # Simulate a training batch
        for _ in range(batch_size):
            idx = iteration % num_samples
            shape = dataset.get_shape(idx)
            total_accesses += 1

    elapsed = time.time() - start

    print(f"\nResults:")
    print(f"  Total shape accesses: {total_accesses}")
    print(f"  Total time: {elapsed:.4f}s")
    print(f"  Average per access: {elapsed / total_accesses * 1000:.4f}ms")
    print(f"  Throughput: {total_accesses / elapsed:.1f} accesses/sec")

    return elapsed, total_accesses


def main():
    print("CDC Shape Cache Benchmark")
    print("=" * 60)

    # Create test cache
    cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))

    try:
        # Benchmark with typical training workload
        # Simulates 1000 training steps with batch_size=8
        benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)

        print("\n" + "=" * 60)
        print("Summary:")
        print("  With in-memory caching, shape access should be:")
        print("  - Sub-millisecond per access")
        print("  - No disk I/O after initial load")
        print("  - Constant time regardless of cache file size")

    finally:
        # Cleanup
        if cache_path.exists():
            cache_path.unlink()
            print(f"\nCleaned up test file: {cache_path}")


if __name__ == "__main__":
    main()
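Assuming the layout implied by the import (the library package at the repository root), the benchmark would be run as python benchmark_cdc_shape_cache.py from the repo root; its final Throughput line is where a figure like the commit message's 9.5M accesses/sec would come from.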
library/cdc_fm.py
@@ -576,12 +576,20 @@ class GammaBDataset:
         # Load metadata
         print(f"Loading Γ_b from {gamma_b_path}...")
         from safetensors import safe_open

         with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
             self.num_samples = int(f.get_tensor('metadata/num_samples').item())
             self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())
+
+            # Cache all shapes in memory to avoid repeated I/O during training
+            # Loading once at init is much faster than opening the file every training step
+            self.shapes_cache = {}
+            for idx in range(self.num_samples):
+                shape_tensor = f.get_tensor(f'shapes/{idx}')
+                self.shapes_cache[idx] = tuple(shape_tensor.numpy().tolist())
+
         print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
+        print(f"Cached {len(self.shapes_cache)} shapes in memory")

     @torch.no_grad()
     def get_gamma_b_sqrt(
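For contrast, a lazy variant (an illustrative sketch, not part of this commit) would fill the cache on first access instead of preloading it; the eager loop above pays the whole file-open cost once at init rather than spreading it across the first epoch:

    def get_shape_lazy(self, idx: int) -> Tuple[int, ...]:
        """Hypothetical lazy variant of get_shape(): cache each shape on first use."""
        shape = self.shapes_cache.get(idx)
        if shape is None:
            from safetensors import safe_open
            with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
                shape = tuple(f.get_tensor(f'shapes/{idx}').numpy().tolist())
            self.shapes_cache[idx] = shape  # later calls are plain dict reads
        return shape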
@@ -644,12 +652,8 @@ class GammaBDataset:
         return eigenvectors, eigenvalues

     def get_shape(self, idx: int) -> Tuple[int, ...]:
-        """Get the original shape for a sample"""
-        from safetensors import safe_open
-
-        with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
-            shape_tensor = f.get_tensor(f'shapes/{idx}')
-            return tuple(shape_tensor.numpy().tolist())
+        """Get the original shape for a sample (cached in memory)"""
+        return self.shapes_cache[idx]

     @torch.no_grad()
     def compute_sigma_t_x(
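After this change a shape lookup is a plain dict read; a minimal usage sketch (the cache filename here is hypothetical):

dataset = GammaBDataset(gamma_b_path="gamma_b.safetensors", device="cpu")
print(dataset.get_shape(0))  # e.g. (16, 64, 64); no file I/O after __init__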