From f90fa1a89a717093286dd784c268811883f5c345 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:44:51 +0900 Subject: [PATCH] feat: backward compatibility for SD/SDXL latent cache (#2276) * fix: improve handling of legacy npz files and add logging for fallback scenarios * fix: simplify fallback handling in SdSdxlLatentsCachingStrategy --- library/strategy_base.py | 88 ++++++++++++++-------------------------- library/strategy_sd.py | 23 +++-------- 2 files changed, 37 insertions(+), 74 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 9a2acdba..5a043342 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -6,11 +6,6 @@ from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import torch - -try: - from numpy.lib import _format_impl as np_format_impl -except ImportError: - from numpy.lib import format as np_format_impl from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection @@ -387,6 +382,8 @@ class LatentsCachingStrategy: _strategy = None # strategy instance: actual strategy class + _warned_fallback_to_old_npz = False # to avoid spamming logs about fallback + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size @@ -429,16 +426,6 @@ class LatentsCachingStrategy: def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError - def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]: - """Get array shape in npz file by only reading the header.""" - if key not in npz: - return None - - with npz.zip.open(key + ".npy") as npy_file: - version = np.lib.format.read_magic(npy_file) - shape, _, _ = np_format_impl._read_array_header(npy_file, version) - return shape - def _default_is_disk_cached_latents_expected( self, latents_stride: int, @@ -447,7 +434,6 @@ class LatentsCachingStrategy: flip_aug: bool, apply_alpha_mask: bool, multi_resolution: bool = False, - fallback_no_reso: bool = False, ) -> bool: """ Args: @@ -457,7 +443,6 @@ class LatentsCachingStrategy: flip_aug: whether to flip images apply_alpha_mask: whether to apply alpha mask multi_resolution: whether to use multi-resolution latents - fallback_no_reso: fallback to legacy key without resolution suffix Returns: bool @@ -475,21 +460,16 @@ class LatentsCachingStrategy: key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" try: - with np.load(npz_path) as npz: - if "latents" + key_reso_suffix not in npz: - if not (multi_resolution and fallback_no_reso): - return False + npz = np.load(npz_path) - latents_shape = self._get_npz_array_shape(npz, "latents") - if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: - return False - - key_reso_suffix = "" - - if flip_aug and "latents_flipped" + key_reso_suffix not in npz: - return False - if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: - return False + # In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility) + # In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files). + if "latents" + key_reso_suffix not in npz and "latents" not in npz: + return False + if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz): + return False + if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz): + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -568,7 +548,7 @@ class LatentsCachingStrategy: self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ - for SD/SDXL + For single resolution architectures (currently no architecture is single resolution specific). Kept for reference. Args: npz_path (str): Path to the npz file. @@ -586,18 +566,13 @@ class LatentsCachingStrategy: return self._default_load_latents_from_disk(None, npz_path, bucket_reso) def _default_load_latents_from_disk( - self, - latents_stride: Optional[int], - npz_path: str, - bucket_reso: Tuple[int, int], - fallback_no_reso: bool = False, + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ Args: latents_stride (Optional[int]): Stride for latents. If None, load all latents. npz_path (str): Path to the npz file. bucket_reso (Tuple[int, int]): The resolution of the bucket. - fallback_no_reso (bool): fallback to legacy key without resolution suffix Returns: Tuple[ @@ -609,31 +584,30 @@ class LatentsCachingStrategy: ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ if latents_stride is None: - expected_latents_size = None key_reso_suffix = "" else: expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW - with np.load(npz_path) as npz: - key_reso_suffix = key_reso_suffix + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + # raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + # Fallback to old npz without resolution suffix + if "latents" not in npz: + raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})") + if not self._warned_fallback_to_old_npz: + logger.warning( + f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version." + ) + self._warned_fallback_to_old_npz = True + key_reso_suffix = "" - if "latents" + key_reso_suffix not in npz: - if not fallback_no_reso or expected_latents_size is None: - raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") - - latents_shape = self._get_npz_array_shape(npz, "latents") - if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: - raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}") - - key_reso_suffix = "" - - latents = npz["latents" + key_reso_suffix] - original_size = npz["original_size" + key_reso_suffix].tolist() - crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() - flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None - alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( self, diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 837b8f5a..4521ae8d 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -145,7 +145,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): self.suffix = ( SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX ) - + @property def cache_suffix(self) -> str: return self.suffix @@ -158,25 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected( - 8, - bucket_reso, - npz_path, - flip_aug, - alpha_mask, - multi_resolution=True, - fallback_no_reso=True, - ) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - return self._default_load_latents_from_disk( - 8, - npz_path, - bucket_reso, - fallback_no_reso=True, - ) + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -184,7 +171,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device)