mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Fix multi-resolution support in cached files
This commit is contained in:
@@ -535,7 +535,11 @@ class CDCPreprocessor:
|
||||
self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata)
|
||||
|
||||
@staticmethod
|
||||
def get_cdc_npz_path(latents_npz_path: str, config_hash: Optional[str] = None) -> str:
|
||||
def get_cdc_npz_path(
|
||||
latents_npz_path: str,
|
||||
config_hash: Optional[str] = None,
|
||||
latent_shape: Optional[Tuple[int, ...]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get CDC cache path from latents cache path
|
||||
|
||||
@@ -543,21 +547,48 @@ class CDCPreprocessor:
|
||||
configuration and CDC parameters. This prevents using stale CDC files when
|
||||
the dataset composition or CDC settings change.
|
||||
|
||||
IMPORTANT: When using multi-resolution training, you MUST pass latent_shape to ensure
|
||||
CDC files are unique per resolution. Without it, different resolutions will overwrite
|
||||
each other's CDC caches, causing dimension mismatch errors.
|
||||
|
||||
Args:
|
||||
latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz")
|
||||
config_hash: Optional 8-char hash of (dataset_dirs + CDC params)
|
||||
If None, returns path without hash (for backward compatibility)
|
||||
latent_shape: Optional latent shape tuple (C, H, W) to make CDC resolution-specific
|
||||
For multi-resolution training, this MUST be provided
|
||||
|
||||
Returns:
|
||||
CDC cache path:
|
||||
- With hash: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
|
||||
- Without: "image_0512x0768_flux_cdc.npz"
|
||||
CDC cache path examples:
|
||||
- With shape + hash: "image_0512x0768_flux_cdc_104x80_a1b2c3d4.npz"
|
||||
- With hash only: "image_0512x0768_flux_cdc_a1b2c3d4.npz"
|
||||
- Without hash: "image_0512x0768_flux_cdc.npz"
|
||||
|
||||
Example multi-resolution scenario:
|
||||
resolution=512 → latent_shape=(16,64,48) → "image_flux_cdc_64x48_hash.npz"
|
||||
resolution=768 → latent_shape=(16,104,80) → "image_flux_cdc_104x80_hash.npz"
|
||||
"""
|
||||
path = Path(latents_npz_path)
|
||||
|
||||
# Build filename components
|
||||
components = [path.stem, "cdc"]
|
||||
|
||||
# Add latent resolution if provided (for multi-resolution training)
|
||||
if latent_shape is not None:
|
||||
if len(latent_shape) >= 3:
|
||||
# Format: HxW (e.g., "104x80" from shape (16, 104, 80))
|
||||
h, w = latent_shape[-2], latent_shape[-1]
|
||||
components.append(f"{h}x{w}")
|
||||
else:
|
||||
raise ValueError(f"latent_shape must have at least 3 dimensions (C, H, W), got {latent_shape}")
|
||||
|
||||
# Add config hash if provided
|
||||
if config_hash:
|
||||
return str(path.with_stem(f"{path.stem}_cdc_{config_hash}"))
|
||||
else:
|
||||
return str(path.with_stem(f"{path.stem}_cdc"))
|
||||
components.append(config_hash)
|
||||
|
||||
# Build final filename
|
||||
new_stem = "_".join(components)
|
||||
return str(path.with_stem(new_stem))
|
||||
|
||||
def compute_all(self) -> int:
|
||||
"""
|
||||
@@ -687,8 +718,8 @@ class CDCPreprocessor:
|
||||
save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else self.batcher.samples
|
||||
|
||||
for sample in save_iter:
|
||||
# Get CDC cache path with config hash
|
||||
cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash)
|
||||
# Get CDC cache path with config hash and latent shape (for multi-resolution support)
|
||||
cdc_path = self.get_cdc_npz_path(sample.latents_npz_path, self.config_hash, sample.shape)
|
||||
|
||||
# Get CDC results for this sample
|
||||
if sample.global_idx in all_results:
|
||||
@@ -748,7 +779,8 @@ class GammaBDataset:
|
||||
def get_gamma_b_sqrt(
|
||||
self,
|
||||
latents_npz_paths: List[str],
|
||||
device: Optional[str] = None
|
||||
device: Optional[str] = None,
|
||||
latent_shape: Optional[Tuple[int, ...]] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get Γ_b^(1/2) components for a batch of latents
|
||||
@@ -756,10 +788,16 @@ class GammaBDataset:
|
||||
Args:
|
||||
latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
|
||||
device: Device to load to (defaults to self.device)
|
||||
latent_shape: Latent shape (C, H, W) to identify which CDC file to load
|
||||
Required for multi-resolution training to avoid loading wrong CDC
|
||||
|
||||
Returns:
|
||||
eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
|
||||
eigenvalues: (B, d_cdc)
|
||||
|
||||
Note:
|
||||
For multi-resolution training, latent_shape MUST be provided to load the correct
|
||||
CDC file. Without it, the wrong CDC file may be loaded, causing dimension mismatch.
|
||||
"""
|
||||
if device is None:
|
||||
device = self.device
|
||||
@@ -768,8 +806,8 @@ class GammaBDataset:
|
||||
eigenvalues_list = []
|
||||
|
||||
for latents_npz_path in latents_npz_paths:
|
||||
# Get CDC cache path with config hash
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash)
|
||||
# Get CDC cache path with config hash and latent shape (for multi-resolution support)
|
||||
cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash, latent_shape)
|
||||
|
||||
# Load CDC data
|
||||
if not Path(cdc_path).exists():
|
||||
|
||||
@@ -519,7 +519,9 @@ def apply_cdc_noise_transformation(
|
||||
B, C, H, W = noise.shape
|
||||
|
||||
# Batch processing: Get CDC data for all samples at once
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device)
|
||||
# Pass latent shape for multi-resolution CDC support
|
||||
latent_shape = (C, H, W)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths, device=device, latent_shape=latent_shape)
|
||||
noise_flat = noise.reshape(B, -1)
|
||||
noise_cdc_flat = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, noise_flat, t_normalized)
|
||||
return noise_cdc_flat.reshape(B, C, H, W)
|
||||
|
||||
Reference in New Issue
Block a user