Multi-resolution dataset for SD1/SDXL (#2269)

* Multi-resolution dataset for SD1/SDXL

* Add fallback to legacy key without resolution suffix

* Support numpy 2.2
This commit is contained in:
woctordho
2026-02-23 14:30:36 +08:00
committed by GitHub
parent 609d1292f6
commit 50694df3cf
2 changed files with 92 additions and 33 deletions

View File

@@ -6,6 +6,11 @@ from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
try:
from numpy.lib import _format_impl as np_format_impl
except ImportError:
from numpy.lib import format as np_format_impl
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
@@ -424,6 +429,16 @@ class LatentsCachingStrategy:
def cache_batch_latents(
    self,
    model: Any,
    batch: List,
    flip_aug: bool,
    alpha_mask: bool,
    random_crop: bool,
):
    """Cache latents for a batch; not implemented at this level.

    This implementation performs no work and always raises
    ``NotImplementedError``.
    NOTE(review): presumably concrete caching strategies override this
    method -- confirm against the subclasses.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]:
    """Return the shape of array ``key`` in an opened npz archive, reading only its header.

    Args:
        npz: Opened npz archive; it must support ``key in npz`` and expose the
            underlying zip file as ``npz.zip``.
        key: Array name inside the archive, without the ``.npy`` extension.

    Returns:
        The shape recorded in the member's ``.npy`` header, or ``None`` when
        ``key`` is not present in the archive.
    """
    if key in npz:
        # Open the raw zip member and parse just the magic string and header,
        # so the array payload itself is never read.
        with npz.zip.open(key + ".npy") as member:
            npy_version = np.lib.format.read_magic(member)
            array_shape, _fortran_order, _dtype = np_format_impl._read_array_header(member, npy_version)
        return array_shape
    return None
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
@@ -432,6 +447,7 @@ class LatentsCachingStrategy:
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
fallback_no_reso: bool = False,
) -> bool:
"""
Args:
@@ -441,6 +457,7 @@ class LatentsCachingStrategy:
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
fallback_no_reso: fallback to legacy key without resolution suffix
Returns:
bool
@@ -458,13 +475,21 @@ class LatentsCachingStrategy:
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
with np.load(npz_path) as npz:
if "latents" + key_reso_suffix not in npz:
if not (multi_resolution and fallback_no_reso):
return False
latents_shape = self._get_npz_array_shape(npz, "latents")
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
return False
key_reso_suffix = ""
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -495,8 +520,8 @@ class LatentsCachingStrategy:
apply_alpha_mask: whether to apply alpha mask
random_crop: whether to random crop images
multi_resolution: whether to use multi-resolution latents
Returns:
Returns:
None
"""
from library import train_util # import here to avoid circular import
@@ -548,52 +573,67 @@ class LatentsCachingStrategy:
Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
    self,
    latents_stride: Optional[int],
    npz_path: str,
    bucket_reso: Tuple[int, int],
    fallback_no_reso: bool = False,
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Load cached latents and their metadata from an npz file.

    Args:
        latents_stride (Optional[int]): Stride for latents. If None, the keys without a
            resolution suffix are read.
        npz_path (str): Path to the npz file.
        bucket_reso (Tuple[int, int]): The resolution of the bucket, as (W, H).
        fallback_no_reso (bool): fallback to legacy key without resolution suffix when the
            suffixed "latents" key is missing and the legacy array has the expected size.

    Returns:
        Tuple[
            Optional[np.ndarray],
            Optional[List[int]],
            Optional[List[int]],
            Optional[np.ndarray],
            Optional[np.ndarray]
        ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask

    Raises:
        ValueError: If the expected "latents" key is missing and the fallback is disabled or
            not applicable, or if the legacy "latents" array is absent or its last two
            dimensions differ from the expected latents size.
    """
    if latents_stride is None:
        expected_latents_size = None
        key_reso_suffix = ""
    else:
        expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
        key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}"  # e.g. "_32x64", HxW

    # The context manager closes the archive once the arrays have been read.
    with np.load(npz_path) as npz:
        if "latents" + key_reso_suffix not in npz:
            if not fallback_no_reso or expected_latents_size is None:
                raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")

            # Legacy fallback: accept the un-suffixed array only if its trailing
            # (H, W) dimensions match the size expected for this bucket.
            latents_shape = self._get_npz_array_shape(npz, "latents")
            if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
                raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}")
            key_reso_suffix = ""

        latents = npz["latents" + key_reso_suffix]
        original_size = npz["original_size" + key_reso_suffix].tolist()
        crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
        # Flipped latents and alpha mask are optional entries.
        flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
        alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
    return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,