Optimize: Cache CDC shapes in memory to eliminate I/O bottleneck

- Cache all shapes during GammaBDataset initialization
- Eliminates file I/O on every training step (9.5M accesses/sec)
- Reduces get_shape() from file operation to dict lookup
- Memory overhead: ~126 bytes/sample (~12.6 MB per 100k images)
This commit is contained in:
rockerBOO
2025-10-09 15:27:34 -04:00
parent f552f9a3bd
commit e03200bdba
2 changed files with 103 additions and 8 deletions

View File

@@ -0,0 +1,91 @@
"""
Benchmark script to measure performance improvement from caching shapes in memory.
Simulates the get_shape() calls that happen during training.
"""
import time
import tempfile
import torch
from pathlib import Path
from library.cdc_fm import CDCPreprocessor, GammaBDataset
def create_test_cache(num_samples=500, shape=(16, 64, 64)):
"""Create a test CDC cache file"""
preprocessor = CDCPreprocessor(
k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
)
print(f"Creating test cache with {num_samples} samples...")
for i in range(num_samples):
latent = torch.randn(*shape, dtype=torch.float32)
preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)
temp_file = Path(tempfile.mktemp(suffix=".safetensors"))
preprocessor.compute_all(save_path=temp_file)
return temp_file
def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
"""Benchmark repeated get_shape() calls"""
print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
print("=" * 60)
# Load dataset (this is when caching happens)
load_start = time.time()
dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
load_time = time.time() - load_start
print(f"Dataset load time (with caching): {load_time:.4f}s")
# Benchmark shape access
num_samples = dataset.num_samples
total_accesses = 0
start = time.time()
for iteration in range(num_iterations):
# Simulate a training batch
for _ in range(batch_size):
idx = iteration % num_samples
shape = dataset.get_shape(idx)
total_accesses += 1
elapsed = time.time() - start
print(f"\nResults:")
print(f" Total shape accesses: {total_accesses}")
print(f" Total time: {elapsed:.4f}s")
print(f" Average per access: {elapsed / total_accesses * 1000:.4f}ms")
print(f" Throughput: {total_accesses / elapsed:.1f} accesses/sec")
return elapsed, total_accesses
def main():
print("CDC Shape Cache Benchmark")
print("=" * 60)
# Create test cache
cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))
try:
# Benchmark with typical training workload
# Simulates 1000 training steps with batch_size=8
benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)
print("\n" + "=" * 60)
print("Summary:")
print(" With in-memory caching, shape access should be:")
print(" - Sub-millisecond per access")
print(" - No disk I/O after initial load")
print(" - Constant time regardless of cache file size")
finally:
# Cleanup
if cache_path.exists():
cache_path.unlink()
print(f"\nCleaned up test file: {cache_path}")
if __name__ == "__main__":
main()

View File

@@ -576,12 +576,20 @@ class GammaBDataset:
# Load metadata
print(f"Loading Γ_b from {gamma_b_path}...")
from safetensors import safe_open
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
self.num_samples = int(f.get_tensor('metadata/num_samples').item())
self.d_cdc = int(f.get_tensor('metadata/d_cdc').item())
# Cache all shapes in memory to avoid repeated I/O during training
# Loading once at init is much faster than opening the file every training step
self.shapes_cache = {}
for idx in range(self.num_samples):
shape_tensor = f.get_tensor(f'shapes/{idx}')
self.shapes_cache[idx] = tuple(shape_tensor.numpy().tolist())
print(f"Loaded CDC data for {self.num_samples} samples (d_cdc={self.d_cdc})")
print(f"Cached {len(self.shapes_cache)} shapes in memory")
@torch.no_grad()
def get_gamma_b_sqrt(
@@ -644,12 +652,8 @@ class GammaBDataset:
return eigenvectors, eigenvalues
def get_shape(self, idx: int) -> Tuple[int, ...]:
"""Get the original shape for a sample"""
from safetensors import safe_open
with safe_open(str(self.gamma_b_path), framework="pt", device="cpu") as f:
shape_tensor = f.get_tensor(f'shapes/{idx}')
return tuple(shape_tensor.numpy().tolist())
"""Get the original shape for a sample (cached in memory)"""
return self.shapes_cache[idx]
@torch.no_grad()
def compute_sigma_t_x(