Fix multi-resolution support in cached files

This commit is contained in:
rockerBOO
2025-10-30 23:27:13 -04:00
parent 0dfafb4fff
commit b4e5d09871
4 changed files with 78 additions and 29 deletions

View File

@@ -52,9 +52,10 @@ class TestCDCPreprocessorIntegration:
# Verify files were created
assert files_saved == 10
-# Verify first CDC file structure (with config hash)
+# Verify first CDC file structure (with config hash and latent shape)
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
-cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
+latent_shape = (16, 4, 4)
+cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
assert cdc_path.exists()
import numpy as np
@@ -112,12 +113,12 @@ class TestCDCPreprocessorIntegration:
assert files_saved == 10
import numpy as np
-# Check shapes are stored in individual files (with config hash)
+# Check shapes are stored in individual files (with config hash and latent shape)
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
-str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
+str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
)
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
-str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
+str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
)
data_0 = np.load(cdc_path_0)
data_5 = np.load(cdc_path_5)
@@ -278,8 +279,9 @@ class TestCDCEndToEnd:
batch_t = torch.rand(batch_size)
latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
-# Get Γ_b components
-eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu")
+# Get Γ_b components (pass latent_shape for multi-resolution support)
+latent_shape = (16, 4, 4)
+eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu", latent_shape=latent_shape)
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)