Fix multi-resolution support in cached files

This commit is contained in:
rockerBOO
2025-10-30 23:27:13 -04:00
parent 0dfafb4fff
commit b4e5d09871
4 changed files with 78 additions and 29 deletions

View File

@@ -535,7 +535,11 @@ class CDCPreprocessor:
self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)
@staticmethod
def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str:
def get_cdc_npz_path(
latents_npz_path: str,
config_hash: Optional[str] = None,
latent_shape: Optional[Tuple[int, ...]] = None
) -> str:
"""
Get CDC cache path from latents cache path
@@ -543,21 +547,48 @@ class CDCPreprocessor:
configuration and CDC parameters. This prevents using stale CDC files when
the dataset composition or CDC settings change.
IMPORTANT: When using multi-resolution training, you MUST pass latent_shape to ensure
CDC files are unique per resolution. Without it, different resolutions will overwrite
each other's CDC caches, causing dimension mismatch errors.
Args:
latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
config_hash: Optional 8-char hash of (dataset_dirs + CDC params)
If None, returns path without hash (for backward compatibility)
latent_shape: Optional latent shape tuple (C, H, W) to make CDC resolution-specific
For multi-resolution training, this MUST be provided
Returns:
CDC cache path:
- With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
- Without: "image_0512x0768_flux_cdc.npz"
CDC cache path examples:
- With shape + hash: "image_0512x0768_flux_cdc_104x80_a1b2c3d4.npz"
- With hash only: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
- Without hash: "image_0512x0768_flux_cdc.npz"
Example multi-resolution scenario:
resolution=512 → latent_shape=(16,64,48) → "image_flux_cdc_64x48_hash.npz"
resolution=768 → latent_shape=(16,104,80) → "image_flux_cdc_104x80_hash.npz"
"""
path = Path(latents_npz_path)
# Build filename components
components = [path.stem, "cdc"]
# Add latent resolution if provided (for multi-resolution training)
if latent_shape is not None:
if len(latent_shape) >= 3:
# Format: HxW (e.g., "104x80" from shape (16, 104, 80))
h, w = latent_shape[-2], latent_shape[-1]
components.append(f"{h}x{w}")
else:
raise ValueError(f"latent_shape must have at least 3 dimensions (C, H, W), got {latent_shape}")
# Add config hash if provided
if config_hash:
return str(path.with_stem(f"{path.stem}_cdc_{config_hash}"))
else:
return str(path.with_stem(f"{path.stem}_cdc"))
components.append(config_hash)
# Build final filename
new_stem = "_".join(components)
return str(path.with_stem(new_stem))
def compute_all(self) -> int:
"""
@@ -687,8 +718,8 @@ class CDCPreprocessor:
save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples
for sample in save_iter:
# Get CDC cache path with config hash
cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash)
# Get CDC cache path with config hash and latent shape (for multi-resolution support)
cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash, sample.shape)
# Get CDC results for this sample
if sample.global_idx in all_results:
@@ -748,7 +779,8 @@ class GammaBDataset:
def get_gamma_b_sqrt(
self,
latents_npz_paths: List[str],
device: Optional[str] = None
device: Optional[str] = None,
latent_shape: Optional[Tuple[int, ...]] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get Γ_b^(1/2) components for a batch of latents
@@ -756,10 +788,16 @@ class GammaBDataset:
Args:
latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
device: Device to load to (defaults to self.device)
latent_shape: Latent shape (C, H, W) to identify which CDC file to load
Required for multi-resolution training to avoid loading wrong CDC
Returns:
eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
eigenvalues: (B, d_cdc)
Note:
For multi-resolution training, latent_shape MUST be provided to load the correct
CDC file. Without it, the wrong CDC file may be loaded, causing dimension mismatch.
"""
if device is None:
device = self.device
@@ -768,8 +806,8 @@ class GammaBDataset:
eigenvalues_list = []
for latents_npz_path in latents_npz_paths:
# Get CDC cache path with config hash
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash)
# Get CDC cache path with config hash and latent shape (for multi-resolution support)
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash, latent_shape)
# Load CDC data
if not Path(cdc_path).exists():

View File

@@ -519,7 +519,9 @@ def apply_cdc_noise_transformation(
B, C, H, W = noise.shape
# Batch processing: Get CDC data for all samples at once
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device)
# Pass latent shape for multi-resolution CDC support
latent_shape = (C, H, W)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device, latent_shape=latent_shape)
noise_flat = noise.reshape(B, -1)
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
return noise_cdc_flat.reshape(B, C, H, W)

View File

@@ -52,9 +52,10 @@ class TestCDCPreprocessorIntegration:
# Verify files were created
assert files_saved == 10
# Verify first CDC file structure (with config hash)
# Verify first CDC file structure (with config hash and latent shape)
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
latent_shape = (16, 4, 4)
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
assert cdc_path.exists()
import numpy as np
@@ -112,12 +113,12 @@ class TestCDCPreprocessorIntegration:
assert files_saved == 10
import numpy as np
# Check shapes are stored in individual files (with config hash)
# Check shapes are stored in individual files (with config hash and latent shape)
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
)
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
)
data_0 = np.load(cdc_path_0)
data_5 = np.load(cdc_path_5)
@@ -278,8 +279,9 @@ class TestCDCEndToEnd:
batch_t = torch.rand(batch_size)
latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu")
# Get Γ_b components (pass latent_shape for multi-resolution support)
latent_shape = (16, 4, 4)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu", latent_shape=latent_shape)
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)

View File

@@ -45,7 +45,8 @@ class TestCDCPreprocessor:
# Verify first CDC file structure
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
latent_shape = (16, 4, 4)
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
assert cdc_path.exists()
data = np.load(cdc_path)
@@ -100,10 +101,10 @@ class TestCDCPreprocessor:
# Check shapes are stored in individual files
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
)
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
)
data_0 = np.load(cdc_path_0)
@@ -152,7 +153,8 @@ class TestGammaBDataset:
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
# Get components for first sample
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu")
latent_shape = (16, 8, 8)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu", latent_shape=latent_shape)
# Check shapes
assert eigvecs.shape[0] == 1 # batch size
@@ -166,7 +168,8 @@ class TestGammaBDataset:
# Get Γ_b for paths [0, 2, 4]
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
latent_shape = (16, 8, 8)
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
# Check shapes
assert eigenvectors.shape[0] == 3 # batch
@@ -187,7 +190,8 @@ class TestGammaBDataset:
# Get Γ_b components
paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]]
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
latent_shape = (16, 8, 8)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
@@ -204,7 +208,8 @@ class TestGammaBDataset:
# Get Γ_b components
paths = [latents_npz_paths[1], latents_npz_paths[3]]
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
latent_shape = (16, 8, 8)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
@@ -221,7 +226,8 @@ class TestGammaBDataset:
# Get Γ_b components
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
latent_shape = (16, 8, 8)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
@@ -269,7 +275,8 @@ class TestCDCEndToEnd:
paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
# Get Γ_b components
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu")
latent_shape = (16, 4, 4)
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu", latent_shape=latent_shape)
# Compute geometry-aware noise
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)