"""
Benchmark script to measure performance improvement from caching shapes in memory.

Simulates the get_shape() calls that happen during training.
"""

import os
import tempfile
import time
from pathlib import Path

import torch

from library.cdc_fm import CDCPreprocessor, GammaBDataset


def create_test_cache(num_samples=500, shape=(16, 64, 64)):
    """Create a test CDC cache file filled with random latents.

    Args:
        num_samples: number of synthetic latents to add to the cache.
        shape: shape of each synthetic latent tensor.

    Returns:
        Path to the written ``.safetensors`` cache file. The caller is
        responsible for deleting it.
    """
    preprocessor = CDCPreprocessor(
        k_neighbors=16, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu"
    )

    print(f"Creating test cache with {num_samples} samples...")
    for i in range(num_samples):
        latent = torch.randn(*shape, dtype=torch.float32)
        preprocessor.add_latent(latent=latent, global_idx=i, shape=shape)

    # tempfile.mktemp() is deprecated and race-prone (another process can
    # claim the name before we create the file). mkstemp() creates the file
    # atomically; we close the fd immediately because compute_all() writes
    # by path, not by descriptor.
    fd, name = tempfile.mkstemp(suffix=".safetensors")
    os.close(fd)
    temp_file = Path(name)
    preprocessor.compute_all(save_path=temp_file)
    return temp_file


def benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8):
    """Benchmark repeated get_shape() calls against a GammaBDataset.

    Args:
        cache_path: path to a CDC ``.safetensors`` cache file.
        num_iterations: number of simulated training steps.
        batch_size: get_shape() calls per simulated step.

    Returns:
        Tuple of (elapsed_seconds, total_accesses) for the access loop only;
        dataset load time is measured and reported separately.
    """
    print(f"\nBenchmarking {num_iterations} iterations with batch_size={batch_size}")
    print("=" * 60)

    # Load dataset (this is when the in-memory shape caching happens)
    load_start = time.time()
    dataset = GammaBDataset(gamma_b_path=cache_path, device="cpu")
    load_time = time.time() - load_start
    print(f"Dataset load time (with caching): {load_time:.4f}s")

    # Benchmark shape access
    num_samples = dataset.num_samples
    total_accesses = 0

    start = time.time()
    for iteration in range(num_iterations):
        # Simulate a training batch; idx is constant within one batch, so
        # hoist it out of the inner loop.
        idx = iteration % num_samples
        for _ in range(batch_size):
            # Result intentionally discarded — we are timing the call itself.
            dataset.get_shape(idx)
            total_accesses += 1

    elapsed = time.time() - start

    print("\nResults:")
    print(f"  Total shape accesses: {total_accesses}")
    print(f"  Total time: {elapsed:.4f}s")
    print(f"  Average per access: {elapsed / total_accesses * 1000:.4f}ms")
    # Guard against a zero-duration measurement on coarse clocks.
    if elapsed > 0:
        print(f"  Throughput: {total_accesses / elapsed:.1f} accesses/sec")

    return elapsed, total_accesses


def main():
    """Create a test cache, run the benchmark, and clean up the temp file."""
    print("CDC Shape Cache Benchmark")
    print("=" * 60)

    # Create test cache
    cache_path = create_test_cache(num_samples=500, shape=(16, 64, 64))

    try:
        # Benchmark with typical training workload:
        # simulates 1000 training steps with batch_size=8.
        benchmark_shape_access(cache_path, num_iterations=1000, batch_size=8)

        print("\n" + "=" * 60)
        print("Summary:")
        print("  With in-memory caching, shape access should be:")
        print("  - Sub-millisecond per access")
        print("  - No disk I/O after initial load")
        print("  - Constant time regardless of cache file size")

    finally:
        # Cleanup the temporary cache file even if the benchmark raised.
        if cache_path.exists():
            cache_path.unlink()
            print(f"\nCleaned up test file: {cache_path}")


if __name__ == "__main__":
    main()
def get_shape(self, idx: int) -> Tuple[int, ...]:
    """Return the original tensor shape recorded for sample *idx*.

    Shapes are preloaded into ``self.shapes_cache`` when the dataset is
    constructed, so this is a pure in-memory dictionary lookup with no
    file I/O.
    """
    cached_shapes = self.shapes_cache
    return cached_shapes[idx]