mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Multi-resolution dataset for SD1/SDXL (#2269)
* Multi-resolution dataset for SD1/SDXL * Add fallback to legacy key without resolution suffix * Support numpy 2.2
This commit is contained in:
@@ -6,6 +6,11 @@ from typing import Any, List, Optional, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from numpy.lib import _format_impl as np_format_impl
|
||||
except ImportError:
|
||||
from numpy.lib import format as np_format_impl
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
|
||||
@@ -424,6 +429,16 @@ class LatentsCachingStrategy:
|
||||
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]:
|
||||
"""Get array shape in npz file by only reading the header."""
|
||||
if key not in npz:
|
||||
return None
|
||||
|
||||
with npz.zip.open(key + ".npy") as npy_file:
|
||||
version = np.lib.format.read_magic(npy_file)
|
||||
shape, _, _ = np_format_impl._read_array_header(npy_file, version)
|
||||
return shape
|
||||
|
||||
def _default_is_disk_cached_latents_expected(
|
||||
self,
|
||||
latents_stride: int,
|
||||
@@ -432,6 +447,7 @@ class LatentsCachingStrategy:
|
||||
flip_aug: bool,
|
||||
apply_alpha_mask: bool,
|
||||
multi_resolution: bool = False,
|
||||
fallback_no_reso: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
@@ -441,6 +457,7 @@ class LatentsCachingStrategy:
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
fallback_no_reso: fallback to legacy key without resolution suffix
|
||||
|
||||
Returns:
|
||||
bool
|
||||
@@ -458,13 +475,21 @@ class LatentsCachingStrategy:
|
||||
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
return False
|
||||
with np.load(npz_path) as npz:
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
if not (multi_resolution and fallback_no_reso):
|
||||
return False
|
||||
|
||||
latents_shape = self._get_npz_array_shape(npz, "latents")
|
||||
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
|
||||
return False
|
||||
|
||||
key_reso_suffix = ""
|
||||
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -495,8 +520,8 @@ class LatentsCachingStrategy:
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
random_crop: whether to random crop images
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from library import train_util # import here to avoid circular import
|
||||
@@ -548,52 +573,67 @@ class LatentsCachingStrategy:
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||
|
||||
def _default_load_latents_from_disk(
|
||||
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
||||
self,
|
||||
latents_stride: Optional[int],
|
||||
npz_path: str,
|
||||
bucket_reso: Tuple[int, int],
|
||||
fallback_no_reso: bool = False,
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
Args:
|
||||
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
fallback_no_reso (bool): fallback to legacy key without resolution suffix
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
if latents_stride is None:
|
||||
expected_latents_size = None
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW
|
||||
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||
with np.load(npz_path) as npz:
|
||||
key_reso_suffix = key_reso_suffix
|
||||
|
||||
latents = npz["latents" + key_reso_suffix]
|
||||
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
|
||||
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
|
||||
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
if not fallback_no_reso or expected_latents_size is None:
|
||||
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||
|
||||
latents_shape = self._get_npz_array_shape(npz, "latents")
|
||||
if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size:
|
||||
raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}")
|
||||
|
||||
key_reso_suffix = ""
|
||||
|
||||
latents = npz["latents" + key_reso_suffix]
|
||||
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
|
||||
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
|
||||
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
def save_latents_to_disk(
|
||||
self,
|
||||
|
||||
@@ -2,6 +2,7 @@ import glob
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
from library import train_util
|
||||
@@ -157,7 +158,25 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
return self._default_is_disk_cached_latents_expected(
|
||||
8,
|
||||
bucket_reso,
|
||||
npz_path,
|
||||
flip_aug,
|
||||
alpha_mask,
|
||||
multi_resolution=True,
|
||||
fallback_no_reso=True,
|
||||
)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(
|
||||
8,
|
||||
npz_path,
|
||||
bucket_reso,
|
||||
fallback_no_reso=True,
|
||||
)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
@@ -165,7 +184,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
Reference in New Issue
Block a user