From b4e5d098711365fd1a08ef8d9a4c5f9b1818e26b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 30 Oct 2025 23:27:13 -0400 Subject: [PATCH] Fix multi-resolution support in cached files --- library/cdc_fm.py | 62 +++++++++++++++++++++----- library/flux_train_utils.py | 4 +- tests/library/test_cdc_preprocessor.py | 16 ++++--- tests/library/test_cdc_standalone.py | 25 +++++++---- 4 files changed, 78 insertions(+), 29 deletions(-) diff --git a/library/cdc_fm.py b/library/cdc_fm.py index 84a8a34a..4a5772ad 100644 --- a/library/cdc_fm.py +++ b/library/cdc_fm.py @@ -535,7 +535,11 @@ class CDCPreprocessor: self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata) @staticmethod - def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str: + def get_cdc_npz_path( + latents_npz_path: str, + config_hash: Optional[str] = None, + latent_shape: Optional[Tuple[int, ...]] = None + ) -> str: """ Get CDC cache path from latents cache path @@ -543,21 +547,48 @@ class CDCPreprocessor: configuration and CDC parameters. This prevents using stale CDC files when the dataset composition or CDC settings change. + IMPORTANT: When using multi-resolution training, you MUST pass latent_shape to ensure + CDC files are unique per resolution. Without it, different resolutions will overwrite + each other's CDC caches, causing dimension mismatch errors. + Args: latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz") config_hash: Optional 8-char hash of (dataset_dirs + CDC params) If None, returns path without hash (for backward compatibility) + latent_shape: Optional latent shape tuple (C, H, W) to make CDC resolution-specific + For multi-resolution training, this MUST be provided Returns: - CDC cache path: - - With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz" - - Without: "image_0512x0768_flux_cdc.npz" + CDC cache path examples: + - With shape + hash: "image_0512x0768_flux_cdc_104x80_a1b2c3d4.npz" + - With hash only: "image_0512x0768_flux_cdc_a1b2c3d4.npz" + - Without hash: "image_0512x0768_flux_cdc.npz" + + Example multi-resolution scenario: + resolution=512 → latent_shape=(16,64,48) → "image_flux_cdc_64x48_hash.npz" + resolution=768 → latent_shape=(16,104,80) → "image_flux_cdc_104x80_hash.npz" """ path = Path(latents_npz_path) + + # Build filename components + components = [path.stem, "cdc"] + + # Add latent resolution if provided (for multi-resolution training) + if latent_shape is not None: + if len(latent_shape) >= 3: + # Format: HxW (e.g., "104x80" from shape (16, 104, 80)) + h, w = latent_shape[-2], latent_shape[-1] + components.append(f"{h}x{w}") + else: + raise ValueError(f"latent_shape must have at least 3 dimensions (C, H, W), got {latent_shape}") + + # Add config hash if provided if config_hash: - return str(path.with_stem(f"{path.stem}_cdc_{config_hash}")) - else: - return str(path.with_stem(f"{path.stem}_cdc")) + components.append(config_hash) + + # Build final filename + new_stem = "_".join(components) + return str(path.with_stem(new_stem)) def compute_all(self) -> int: """ @@ -687,8 +718,8 @@ class CDCPreprocessor: save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples for sample in save_iter: - # Get CDC cache path with config hash - cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash) + # Get CDC cache path with config hash and latent shape (for multi-resolution support) + cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash, sample.shape) # Get CDC results for this sample if sample.global_idx in all_results: @@ -748,7 +779,8 @@ class GammaBDataset: def get_gamma_b_sqrt( self, latents_npz_paths: List[str], - device: Optional[str] = None + device: Optional[str] = None, + latent_shape: Optional[Tuple[int, ...]] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get Γ_b^(1/2) components for a batch of latents @@ -756,10 +788,16 @@ class GammaBDataset: Args: latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...]) device: Device to load to (defaults to self.device) + latent_shape: Latent shape (C, H, W) to identify which CDC file to load + Required for multi-resolution training to avoid loading wrong CDC Returns: eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample! eigenvalues: (B, d_cdc) + + Note: + For multi-resolution training, latent_shape MUST be provided to load the correct + CDC file. Without it, the wrong CDC file may be loaded, causing dimension mismatch. """ if device is None: device = self.device @@ -768,8 +806,8 @@ class GammaBDataset: eigenvalues_list = [] for latents_npz_path in latents_npz_paths: - # Get CDC cache path with config hash - cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash) + # Get CDC cache path with config hash and latent shape (for multi-resolution support) + cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash, latent_shape) # Load CDC data if not Path(cdc_path).exists(): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 295660c2..ca030730 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -519,7 +519,9 @@ def apply_cdc_noise_transformation( B, C, H, W = noise.shape # Batch processing: Get CDC data for all samples at once - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device) + # Pass latent shape for multi-resolution CDC support + latent_shape = (C, H, W) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device, latent_shape=latent_shape) noise_flat = noise.reshape(B, -1) noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized) return noise_cdc_flat.reshape(B, C, H, W) diff --git a/tests/library/test_cdc_preprocessor.py b/tests/library/test_cdc_preprocessor.py index 21005bab..d8c92573 100644 --- a/tests/library/test_cdc_preprocessor.py +++ b/tests/library/test_cdc_preprocessor.py @@ -52,9 +52,10 @@ class TestCDCPreprocessorIntegration: # Verify files were created assert files_saved == 10 - # Verify first CDC file structure (with config hash) + # Verify first CDC file structure (with config hash and latent shape) latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz") - cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash)) + latent_shape = (16, 4, 4) + cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape)) assert cdc_path.exists() import numpy as np @@ -112,12 +113,12 @@ class TestCDCPreprocessorIntegration: assert files_saved == 10 import numpy as np - # Check shapes are stored in individual files (with config hash) + # Check shapes are stored in individual files (with config hash and latent shape) cdc_path_0 = CDCPreprocessor.get_cdc_npz_path( - str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash + str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4) ) cdc_path_5 = CDCPreprocessor.get_cdc_npz_path( - str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash + str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8) ) data_0 = np.load(cdc_path_0) data_5 = np.load(cdc_path_5) @@ -278,8 +279,9 @@ class TestCDCEndToEnd: batch_t = torch.rand(batch_size) latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]] - # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu") + # Get Γ_b components (pass latent_shape for multi-resolution support) + latent_shape = (16, 4, 4) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu", latent_shape=latent_shape) # Compute geometry-aware noise sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) diff --git a/tests/library/test_cdc_standalone.py b/tests/library/test_cdc_standalone.py index 6815b4da..c5a6914a 100644 --- a/tests/library/test_cdc_standalone.py +++ b/tests/library/test_cdc_standalone.py @@ -45,7 +45,8 @@ class TestCDCPreprocessor: # Verify first CDC file structure latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz") - cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash)) + latent_shape = (16, 4, 4) + cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape)) assert cdc_path.exists() data = np.load(cdc_path) @@ -100,10 +101,10 @@ class TestCDCPreprocessor: # Check shapes are stored in individual files cdc_path_0 = CDCPreprocessor.get_cdc_npz_path( - str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash + str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4) ) cdc_path_5 = CDCPreprocessor.get_cdc_npz_path( - str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash + str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8) ) data_0 = np.load(cdc_path_0) @@ -152,7 +153,8 @@ class TestGammaBDataset: gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) # Get components for first sample - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu") + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu", latent_shape=latent_shape) # Check shapes assert eigvecs.shape[0] == 1 # batch size @@ -166,7 +168,8 @@ class TestGammaBDataset: # Get Γ_b for paths [0, 2, 4] paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]] - eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") + latent_shape = (16, 8, 8) + eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) # Check shapes assert eigenvectors.shape[0] == 3 # batch @@ -187,7 +190,8 @@ class TestGammaBDataset: # Get Γ_b components paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]] - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) @@ -204,7 +208,8 @@ class TestGammaBDataset: # Get Γ_b components paths = [latents_npz_paths[1], latents_npz_paths[3]] - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) @@ -221,7 +226,8 @@ class TestGammaBDataset: # Get Γ_b components paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]] - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu") + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) @@ -269,7 +275,8 @@ class TestCDCEndToEnd: paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]] # Get Γ_b components - eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu") + latent_shape = (16, 4, 4) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu", latent_shape=latent_shape) # Compute geometry-aware noise sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)