feat: backward compatibility for SD/SDXL latent cache (#2276)

* fix: improve handling of legacy npz files and add logging for fallback scenarios

* fix: simplify fallback handling in SdSdxlLatentsCachingStrategy
This commit is contained in:
Kohya S.
2026-02-23 21:44:51 +09:00
committed by GitHub
parent 98a42e4cd6
commit f90fa1a89a
2 changed files with 37 additions and 74 deletions

View File

@@ -6,11 +6,6 @@ from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
try:
from numpy.lib import _format_impl as np_format_impl
except ImportError:
from numpy.lib import format as np_format_impl
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
@@ -387,6 +382,8 @@ class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
_warned_fallback_to_old_npz = False # to avoid spamming logs about fallback
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
@@ -429,16 +426,6 @@ class LatentsCachingStrategy:
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]:
"""Get array shape in npz file by only reading the header."""
if key not in npz:
return None
with npz.zip.open(key + ".npy") as npy_file:
version = np.lib.format.read_magic(npy_file)
shape, _, _ = np_format_impl._read_array_header(npy_file, version)
return shape
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
@@ -447,7 +434,6 @@ class LatentsCachingStrategy:
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
fallback_no_reso: bool = False,
) -> bool:
"""
Args:
@@ -457,7 +443,6 @@ class LatentsCachingStrategy:
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
fallback_no_reso: fallback to legacy key without resolution suffix
Returns:
bool
@@ -475,21 +460,16 @@ class LatentsCachingStrategy:
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
with np.load(npz_path) as npz:
if "latents" + key_reso_suffix not in npz:
if not (multi_resolution and fallback_no_reso):
return False
npz = np.load(npz_path)
latents_shape = self._get_npz_array_shape(npz, "latents")
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
return False
key_reso_suffix = ""
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
# In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility)
# In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files).
if "latents" + key_reso_suffix not in npz and "latents" not in npz:
return False
if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz):
return False
if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz):
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -568,7 +548,7 @@ class LatentsCachingStrategy:
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
For single resolution architectures (currently no architecture is single resolution specific). Kept for reference.
Args:
npz_path (str): Path to the npz file.
@@ -586,18 +566,13 @@ class LatentsCachingStrategy:
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self,
latents_stride: Optional[int],
npz_path: str,
bucket_reso: Tuple[int, int],
fallback_no_reso: bool = False,
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
fallback_no_reso (bool): fallback to legacy key without resolution suffix
Returns:
Tuple[
@@ -609,31 +584,30 @@ class LatentsCachingStrategy:
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if latents_stride is None:
expected_latents_size = None
key_reso_suffix = ""
else:
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW
with np.load(npz_path) as npz:
key_reso_suffix = key_reso_suffix
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
# raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# Fallback to old npz without resolution suffix
if "latents" not in npz:
raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})")
if not self._warned_fallback_to_old_npz:
logger.warning(
f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version."
)
self._warned_fallback_to_old_npz = True
key_reso_suffix = ""
if "latents" + key_reso_suffix not in npz:
if not fallback_no_reso or expected_latents_size is None:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
latents_shape = self._get_npz_array_shape(npz, "latents")
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}")
key_reso_suffix = ""
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,

View File

@@ -145,7 +145,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
@@ -158,25 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(
8,
bucket_reso,
npz_path,
flip_aug,
alpha_mask,
multi_resolution=True,
fallback_no_reso=True,
)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(
8,
npz_path,
bucket_reso,
fallback_no_reso=True,
)
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -184,7 +171,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)