diff --git a/library/strategy_base.py b/library/strategy_base.py
index 6e6487ea..9a2acdba 100644
--- a/library/strategy_base.py
+++ b/library/strategy_base.py
@@ -6,6 +6,11 @@
 from typing import Any, List, Optional, Tuple, Union, Callable
 
 import numpy as np
 import torch
+
+try:
+    from numpy.lib import _format_impl as np_format_impl
+except ImportError:
+    from numpy.lib import format as np_format_impl
 
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
@@ -424,6 +429,16 @@ class LatentsCachingStrategy:
     def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
         raise NotImplementedError
 
+    def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]:
+        """Get the shape of an array in an npz file by reading only its header, without loading the data."""
+        if key not in npz:
+            return None
+
+        with npz.zip.open(key + ".npy") as npy_file:
+            version = np_format_impl.read_magic(npy_file)
+            shape, _, _ = np_format_impl._read_array_header(npy_file, version)
+            return shape
+
     def _default_is_disk_cached_latents_expected(
         self,
         latents_stride: int,
@@ -432,6 +447,7 @@ class LatentsCachingStrategy:
         flip_aug: bool,
         apply_alpha_mask: bool,
         multi_resolution: bool = False,
+        fallback_no_reso: bool = False,
     ) -> bool:
         """
         Args:
@@ -441,6 +457,7 @@ class LatentsCachingStrategy:
             flip_aug: whether to flip images
            apply_alpha_mask: whether to apply alpha mask
            multi_resolution: whether to use multi-resolution latents
+            fallback_no_reso: whether to fall back to the legacy key without a resolution suffix
 
         Returns:
             bool
@@ -458,13 +475,21 @@ class LatentsCachingStrategy:
         key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
 
         try:
-            npz = np.load(npz_path)
-            if "latents" + key_reso_suffix not in npz:
-                return False
-            if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
-                return False
-            if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
-                return False
+            with np.load(npz_path) as npz:
+                if "latents" + key_reso_suffix not in npz:
+                    if not (multi_resolution and fallback_no_reso):
+                        return False
+
+                    latents_shape = self._get_npz_array_shape(npz, "latents")
+                    if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
+                        return False
+
+                    key_reso_suffix = ""
+
+                if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
+                    return False
+                if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
+                    return False
         except Exception as e:
             logger.error(f"Error loading file: {npz_path}")
             raise e
@@ -495,8 +520,8 @@ class LatentsCachingStrategy:
             apply_alpha_mask: whether to apply alpha mask
             random_crop: whether to random crop images
             multi_resolution: whether to use multi-resolution latents
-
-        Returns:
+
+        Returns:
             None
         """
         from library import train_util  # import here to avoid circular import
@@ -548,52 +573,66 @@ class LatentsCachingStrategy:
         Args:
             npz_path (str): Path to the npz file.
             bucket_reso (Tuple[int, int]): The resolution of the bucket.
-
+
         Returns:
             Tuple[
-                Optional[np.ndarray],
-                Optional[List[int]],
-                Optional[List[int]],
-                Optional[np.ndarray],
+                Optional[np.ndarray],
+                Optional[List[int]],
+                Optional[List[int]],
+                Optional[np.ndarray],
                 Optional[np.ndarray]
             ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
         """
         return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
 
     def _default_load_latents_from_disk(
-        self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
+        self,
+        latents_stride: Optional[int],
+        npz_path: str,
+        bucket_reso: Tuple[int, int],
+        fallback_no_reso: bool = False,
     ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
         """
         Args:
             latents_stride (Optional[int]): Stride for latents. If None, load all latents.
             npz_path (str): Path to the npz file.
             bucket_reso (Tuple[int, int]): The resolution of the bucket.
-
+            fallback_no_reso (bool): Whether to fall back to the legacy key without a resolution suffix.
+
         Returns:
             Tuple[
-                Optional[np.ndarray],
-                Optional[List[int]],
-                Optional[List[int]],
-                Optional[np.ndarray],
+                Optional[np.ndarray],
+                Optional[List[int]],
+                Optional[List[int]],
+                Optional[np.ndarray],
                 Optional[np.ndarray]
             ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
         """
         if latents_stride is None:
+            expected_latents_size = None
             key_reso_suffix = ""
         else:
-            latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
-            key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}"  # e.g. "_32x64", HxW
+            expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
+            key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}"  # e.g. "_32x64", HxW
"_32x64", HxW - npz = np.load(npz_path) - if "latents" + key_reso_suffix not in npz: - raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + with np.load(npz_path) as npz: + key_reso_suffix = key_reso_suffix - latents = npz["latents" + key_reso_suffix] - original_size = npz["original_size" + key_reso_suffix].tolist() - crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() - flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None - alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + if "latents" + key_reso_suffix not in npz: + if not fallback_no_reso or expected_latents_size is None: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents_shape = self._get_npz_array_shape(npz, "latents") + if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: + raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}") + + key_reso_suffix = "" + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( self, diff --git a/library/strategy_sd.py b/library/strategy_sd.py index a44fc409..837b8f5a 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -2,6 +2,7 @@ import glob import os from typing import Any, List, Optional, Tuple, Union +import numpy as np import torch from transformers import CLIPTokenizer from library import train_util @@ -157,7 +158,25 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected( + 8, + bucket_reso, + npz_path, + flip_aug, + alpha_mask, + multi_resolution=True, + fallback_no_reso=True, + ) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk( + 8, + npz_path, + bucket_reso, + fallback_no_reso=True, + ) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -165,7 +184,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device)