feat: backward compatibility for SD/SDXL latent cache (#2276)

* fix: improve handling of legacy npz files and add logging for fallback scenarios

* fix: simplify fallback handling in SdSdxlLatentsCachingStrategy
This commit is contained in:
Kohya S.
2026-02-23 21:44:51 +09:00
committed by GitHub
parent 98a42e4cd6
commit f90fa1a89a
2 changed files with 37 additions and 74 deletions

View File

@@ -6,11 +6,6 @@ from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np import numpy as np
import torch import torch
try:
from numpy.lib import _format_impl as np_format_impl
except ImportError:
from numpy.lib import format as np_format_impl
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
@@ -387,6 +382,8 @@ class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class _strategy = None # strategy instance: actual strategy class
_warned_fallback_to_old_npz = False # to avoid spamming logs about fallback
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk self._cache_to_disk = cache_to_disk
self._batch_size = batch_size self._batch_size = batch_size
@@ -429,16 +426,6 @@ class LatentsCachingStrategy:
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError raise NotImplementedError
def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]:
"""Get array shape in npz file by only reading the header."""
if key not in npz:
return None
with npz.zip.open(key + ".npy") as npy_file:
version = np.lib.format.read_magic(npy_file)
shape, _, _ = np_format_impl._read_array_header(npy_file, version)
return shape
def _default_is_disk_cached_latents_expected( def _default_is_disk_cached_latents_expected(
self, self,
latents_stride: int, latents_stride: int,
@@ -447,7 +434,6 @@ class LatentsCachingStrategy:
flip_aug: bool, flip_aug: bool,
apply_alpha_mask: bool, apply_alpha_mask: bool,
multi_resolution: bool = False, multi_resolution: bool = False,
fallback_no_reso: bool = False,
) -> bool: ) -> bool:
""" """
Args: Args:
@@ -457,7 +443,6 @@ class LatentsCachingStrategy:
flip_aug: whether to flip images flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents multi_resolution: whether to use multi-resolution latents
fallback_no_reso: fallback to legacy key without resolution suffix
Returns: Returns:
bool bool
@@ -475,21 +460,16 @@ class LatentsCachingStrategy:
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try: try:
with np.load(npz_path) as npz: npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
if not (multi_resolution and fallback_no_reso):
return False
latents_shape = self._get_npz_array_shape(npz, "latents") # In old SD/SDXL npz files, a "latents" key without a resolution suffix is accepted even if the actual latents shape does not match the expected shape (backward compatibility).
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: # In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix and no "latents" key without a suffix exists, so this check still fails (returns False) when the key with the expected resolution suffix is not found; the behavior for non-SD/SDXL npz files is unchanged.
return False if "latents" + key_reso_suffix not in npz and "latents" not in npz:
return False
key_reso_suffix = "" if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz):
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz: if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz):
return False return False
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -568,7 +548,7 @@ class LatentsCachingStrategy:
self, npz_path: str, bucket_reso: Tuple[int, int] self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
""" """
for SD/SDXL For single-resolution architectures (currently no architecture is single-resolution specific). Kept for reference.
Args: Args:
npz_path (str): Path to the npz file. npz_path (str): Path to the npz file.
@@ -586,18 +566,13 @@ class LatentsCachingStrategy:
return self._default_load_latents_from_disk(None, npz_path, bucket_reso) return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk( def _default_load_latents_from_disk(
self, self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
latents_stride: Optional[int],
npz_path: str,
bucket_reso: Tuple[int, int],
fallback_no_reso: bool = False,
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
""" """
Args: Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents. latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file. npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket. bucket_reso (Tuple[int, int]): The resolution of the bucket.
fallback_no_reso (bool): fallback to legacy key without resolution suffix
Returns: Returns:
Tuple[ Tuple[
@@ -609,31 +584,30 @@ class LatentsCachingStrategy:
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
""" """
if latents_stride is None: if latents_stride is None:
expected_latents_size = None
key_reso_suffix = "" key_reso_suffix = ""
else: else:
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW
with np.load(npz_path) as npz: npz = np.load(npz_path)
key_reso_suffix = key_reso_suffix if "latents" + key_reso_suffix not in npz:
# raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# Fallback to old npz without resolution suffix
if "latents" not in npz:
raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})")
if not self._warned_fallback_to_old_npz:
logger.warning(
f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version."
)
self._warned_fallback_to_old_npz = True
key_reso_suffix = ""
if "latents" + key_reso_suffix not in npz: latents = npz["latents" + key_reso_suffix]
if not fallback_no_reso or expected_latents_size is None: original_size = npz["original_size" + key_reso_suffix].tolist()
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
latents_shape = self._get_npz_array_shape(npz, "latents") alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}")
key_reso_suffix = ""
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk( def save_latents_to_disk(
self, self,

View File

@@ -145,7 +145,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
self.suffix = ( self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
) )
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return self.suffix return self.suffix
@@ -158,25 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected( return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
8,
bucket_reso,
npz_path,
flip_aug,
alpha_mask,
multi_resolution=True,
fallback_no_reso=True,
)
def load_latents_from_disk( def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int] self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk( return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
8,
npz_path,
bucket_reso,
fallback_no_reso=True,
)
# TODO remove circular dependency for ImageInfo # TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -184,7 +171,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device vae_device = vae.device
vae_dtype = vae.dtype vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True) self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM: if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device) train_util.clean_memory_on_device(vae.device)