feat: refactor latent cache format

Kohya S
2024-11-15 21:16:49 +09:00
parent 0047bb1fc3
commit bdac55ebbc
17 changed files with 356 additions and 633 deletions

View File

@@ -177,7 +177,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -180,7 +180,7 @@ def main(args):
# add to batch
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_npz = npz_file_name
image_info.latents_cache_path = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image

View File

@@ -198,7 +198,7 @@ def train(args):
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator)
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)

View File

@@ -2,9 +2,10 @@
import os
import re
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from safetensors.torch import safe_open, save_file
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
@@ -12,6 +13,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library import utils
from library.utils import setup_logging
setup_logging()
@@ -20,6 +22,27 @@ import logging
logger = logging.getLogger(__name__)
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
if dtype is None:
# all dtypes are acceptable
return get_available_dtypes()
dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
compatible_dtypes = [torch.float32]
if dtype.itemsize == 1: # fp8
compatible_dtypes.append(torch.bfloat16)
compatible_dtypes.append(torch.float16)
compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8
return compatible_dtypes
def get_available_dtypes() -> List[torch.dtype]:
"""
Returns the list of available dtypes for latents caching. Higher precision is preferred.
"""
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@@ -382,11 +405,18 @@ class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
def __init__(
self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
self._architecture = architecture
self._latents_stride = latents_stride
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self.load_version_warning_printed = False
self.save_version_warning_printed = False
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
@@ -397,6 +427,14 @@ class LatentsCachingStrategy:
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def latents_stride(self):
return self._latents_stride
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -407,69 +445,143 @@ class LatentsCachingStrategy:
@property
def cache_suffix(self):
raise NotImplementedError
return f"_{self.architecture.lower()}.safetensors"
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def get_latents_cache_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[Union[str, torch.dtype]],
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def get_key_suffix(
self,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
) -> str:
"""
if dtype is None, it returns "_32x64" for example.
"""
if latents_size is not None:
expected_latents_size = latents_size # H, W
else:
# bucket_reso is (W, H)
expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W
if dtype is None:
dtype_suffix = ""
else:
dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)
# e.g. "_32x64_float16", HxW, dtype
key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"
return key_suffix
def get_compatible_latents_keys(
self,
keys: set[str],
dtype: Union[str, torch.dtype],
flip_aug: bool,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""
bucket_reso is (W, H), latents_size is (H, W)
"""
latents_key = None
flipped_latents_key = None
compatible_dtypes = get_compatible_dtypes(dtype)
for compat_dtype in compatible_dtypes:
key_suffix = self.get_key_suffix(bucket_reso, latents_size, compat_dtype)
if latents_key is None:
latents_key = "latents" + key_suffix
if latents_key not in keys:
latents_key = None
if flip_aug and flipped_latents_key is None:
flipped_latents_key = "latents_flipped" + key_suffix
if flipped_latents_key not in keys:
flipped_latents_key = None
if latents_key is not None and (flipped_latents_key is not None or not flip_aug):
break
return latents_key, flipped_latents_key
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
latents_cache_path: str,
flip_aug: bool,
alpha_mask: bool,
multi_resolution: bool = False,
preferred_dtype: Optional[Union[str, torch.dtype]],
):
# multi_resolution is always enabled for any strategy
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
if not os.path.exists(latents_cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
# safe_open locks the file, so we cannot use it for checking keys
# with safe_open(latents_cache_path, framework="pt") as f:
# keys = f.keys()
with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
keys = f.keys()
if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
# print(f"alpha_mask not found: {latents_cache_path}")
return False
if preferred_dtype is None:
# remove dtype suffix from keys, because any dtype is acceptable
keys = [key.rsplit("_", 1)[0] for key in keys if not key.endswith(key_suffix_without_dtype)]
keys = set(keys)
if "latents" + key_suffix_without_dtype not in keys:
# print(f"No preferred: latents {key_suffix_without_dtype} not found: {latents_cache_path}")
return False
if flip_aug and "latents_flipped" + key_suffix_without_dtype not in keys:
# print(f"No preferred: latents_flipped {key_suffix_without_dtype} not found: {latents_cache_path}")
return False
else:
# specific dtype or compatible dtype is required
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
)
if latents_key is None or (flip_aug and flipped_latents_key is None):
# print(f"Precise dtype not found: {latents_cache_path}")
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
logger.error(f"Error loading file: {latents_cache_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -499,13 +611,8 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -515,56 +622,111 @@ class LatentsCachingStrategy:
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
raise NotImplementedError
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
)
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
keys = f.keys()
latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
alpha_mask_key = "alpha_mask" + key_suffix_without_dtype
latents = f.get_tensor(latents_key)
flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None
original_size = [int(metadata["width"]), int(metadata["height"])]
crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype]
crop_ltrb = list(map(int, crop_ltrb.split(",")))
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
cache_path: str,
latents_tensor: torch.Tensor,
original_size: Tuple[int, int],
crop_ltrb: List[int],
flipped_latents_tensor: Optional[torch.Tensor] = None,
alpha_mask: Optional[torch.Tensor] = None,
):
kwargs = {}
dtype = latents_tensor.dtype
latents_size = latents_tensor.shape[1:3] # H, W
tensor_dict = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
# we cannot use safe_open here because it locks the file
# with safe_open(cache_path, framework="pt") as f:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
keys = f.keys()
for key in keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["width"] = f"{original_size[0]}"
metadata["height"] = f"{original_size[1]}"
metadata["format_version"] = "1.0.0"
metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
if latents_tensor is not None:
tensor_dict["latents" + key_suffix] = latents_tensor
if flipped_latents_tensor is not None:
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)
key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask
# remove lower precision latents if higher precision latents are already cached
if overwrite:
available_dtypes = get_available_dtypes()
available_itemsize = None
available_itemsize_flipped = None
for dtype in available_dtypes:
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
if "latents" + key_suffix in tensor_dict:
if available_itemsize is None:
available_itemsize = dtype.itemsize
elif available_itemsize > dtype.itemsize:
# if higher precision latents are already cached, remove lower precision latents
del tensor_dict["latents" + key_suffix]
if "latents_flipped" + key_suffix in tensor_dict:
if available_itemsize_flipped is None:
available_itemsize_flipped = dtype.itemsize
elif available_itemsize_flipped > dtype.itemsize:
del tensor_dict["latents_flipped" + key_suffix]
save_file(tensor_dict, cache_path, metadata=metadata)
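For orientation, the following is a minimal sketch, not part of this commit, of what a cache file written by the new save_latents_to_disk looks like and how it can be inspected with the safetensors API. The file name and tensor shapes are assumed example values; the key layout and metadata fields follow the code above.

# Illustrative only: inspect a latents cache file in the new safetensors format.
from safetensors.torch import safe_open

cache_path = "img_0001_1024x0768_flux.safetensors"  # assumed example: <stem>_<WWWW>x<HHHH>_<architecture>.safetensors

with safe_open(cache_path, framework="pt") as f:
    # metadata written by save_latents_to_disk: architecture, width, height,
    # format_version ("1.0.0") and crop_ltrb_{H}x{W} as a comma-separated string
    print(f.metadata())
    # tensor keys carry an "_{H}x{W}_{dtype}" suffix, e.g. "latents_96x128_bfloat16"
    # and "latents_flipped_96x128_bfloat16"; alpha_mask omits the dtype: "alpha_mask_96x128"
    for key in f.keys():
        print(key, tuple(f.get_tensor(key).shape))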

View File

@@ -195,29 +195,25 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
ARCHITECTURE = "flux"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -225,9 +221,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
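As a rough illustration of the naming the refactored Flux strategy produces, here is a small sketch derived from the base-class helpers above; the image path and bucket are assumed example values, not code from this commit.

# Sketch: cache path and key suffix for one Flux bucket (assumed example values).
import os

architecture, latents_stride = "flux", 8
absolute_path = "/data/img.png"   # assumed image path
image_size = (1024, 768)          # bucket (W, H)

cache_path = os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + f"_{architecture}.safetensors"
# -> /data/img_1024x0768_flux.safetensors

latents_h, latents_w = image_size[1] // latents_stride, image_size[0] // latents_stride  # (H, W) = (96, 128)
print(cache_path, f"latents_{latents_h}x{latents_w}_bfloat16")  # dtype suffix depends on the dtype used when caching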

View File

@@ -134,30 +134,28 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
ARCHITECTURE_SD = "sd"
ARCHITECTURE_SDXL = "sdxl"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
) -> bool:
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):

View File

@@ -382,29 +382,25 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
ARCHITECTURE_SD3 = "sd3"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -412,9 +408,7 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -158,11 +158,10 @@ class ImageInfo:
self.bucket_reso: Tuple[int, int] = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_npz: Optional[str] = None # set in cache_latents
self.latents_cache_path: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
None # crop left top right bottom in original pixel size, not latents size
)
# crop left top right bottom in original pixel size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
@@ -323,6 +322,9 @@ class BucketManager:
else:
resized_width = bucket_reso[0]
resized_height = bucket_reso[0] / image_ar
resized_width = int(resized_width + 0.5)
resized_height = int(resized_height + 0.5)
crop_left = (bucket_reso[0] - resized_width) // 2
crop_top = (bucket_reso[1] - resized_height) // 2
crop_right = crop_left + resized_width
@@ -1040,7 +1042,7 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def new_cache_latents(self, model: Any, accelerator: Accelerator):
def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool = False):
r"""
a brand new method to cache latents. This method caches latents with caching strategy.
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
@@ -1094,17 +1096,18 @@ class BaseDataset(torch.utils.data.Dataset):
try:
# iterate images
logger.info("caching latents...")
logger.info(f"Caching latents for dataset with {len(image_infos)} images.")
preferred_dtype = model.dtype if force_cache_precision else None
for i, info in enumerate(tqdm(image_infos)):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
if info.latents_cache_path is not None: # fine tuning dataset
continue
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
info.latents_cache_path = caching_strategy.get_latents_cache_path(info.absolute_path, info.image_size)
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different latents
@@ -1114,7 +1117,7 @@ class BaseDataset(torch.utils.data.Dataset):
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
cache_available = caching_strategy.is_disk_cached_latents_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
info.bucket_reso, info.latents_cache_path, subset.flip_aug, subset.alpha_mask, preferred_dtype
)
if cache_available: # do not add to batch
continue
@@ -1144,81 +1147,6 @@ class BaseDataset(torch.utils.data.Dataset):
finally:
executor.shutdown()
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# multi-GPU is not supported here; use tools/cache_latents.py for that
logger.info("caching latents.")
image_infos = list(self.image_data.values())
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution and some conditions
class Condition:
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
self.reso = reso
self.flip_aug = flip_aug
self.alpha_mask = alpha_mask
self.random_crop = random_crop
def __eq__(self, other):
return (
self.reso == other.reso
and self.flip_aug == other.flip_aug
and self.alpha_mask == other.alpha_mask
and self.random_crop == other.random_crop
)
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = []
current_condition = None
logger.info("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
continue
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
)
if cache_available: # do not add to batch
continue
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
batch.append(info)
current_condition = condition
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append((current_condition, batch))
batch = []
current_condition = None
if len(batch) > 0:
batches.append((current_condition, batch))
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
return
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
r"""
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
@@ -1275,131 +1203,6 @@ class BaseDataset(torch.utils.data.Dataset):
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
# to support SD1/2, it needs a flag for v2, but it is postponed
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True
):
assert len(tokenizers) == 2, "only support SDXL"
return self.cache_text_encoder_outputs_common(
tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process
)
# same as above, but for SD3
def cache_text_encoder_outputs_sd3(
self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
):
return self.cache_text_encoder_outputs_common(
[tokenizer],
text_encoders,
devices,
output_dtype,
te_dtypes,
cache_to_disk,
is_main_process,
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
batch_size,
)
def cache_text_encoder_outputs_common(
self,
tokenizers,
text_encoders,
devices,
output_dtype,
te_dtypes,
cache_to_disk=False,
is_main_process=True,
file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
batch_size=None,
):
# supports caching to disk, in the same way as the latents cache
# multi-GPU is not supported here; use tools/cache_latents.py for that
logger.info("caching text encoder outputs.")
tokenize_strategy = TokenizeStrategy.get_strategy()
if batch_size is None:
batch_size = self.batch_size
image_infos = list(self.image_data.values())
logger.info("checking cache existence...")
image_infos_to_cache = []
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.text_encoder_outputs_npz = te_out_npz
if not is_main_process: # store to info only
continue
if os.path.exists(te_out_npz):
# TODO check varidity of cache here
continue
image_infos_to_cache.append(info)
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
return
# prepare tokenizers and text encoders
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
text_encoder.to(device)
if te_dtype is not None:
text_encoder.to(dtype=te_dtype)
# create batch
is_sd3 = len(tokenizers) == 1
batch = []
batches = []
for info in image_infos_to_cache:
if not is_sd3:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
batch.append((info, input_ids1, input_ids2))
else:
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
batch.append((info, l_tokens, g_tokens, t5_tokens))
if len(batch) >= batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
# iterate batches: call text encoder and cache outputs for memory or disk
logger.info("caching text encoder outputs...")
if not is_sd3:
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0)
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype
)
else:
for batch in tqdm(batches):
infos, l_tokens, g_tokens, t5_tokens = zip(*batch)
# stack tokens
# l_tokens = [tokens[0] for tokens in l_tokens]
# g_tokens = [tokens[0] for tokens in g_tokens]
# t5_tokens = [tokens[0] for tokens in t5_tokens]
cache_batch_text_encoder_outputs_sd3(
infos,
tokenizers[0],
text_encoders,
self.max_token_length,
cache_to_disk,
(l_tokens, g_tokens, t5_tokens),
output_dtype,
)
def get_image_size(self, image_path):
# return imagesize.get(image_path)
image_size = imagesize.get(image_path)
@@ -1522,17 +1325,14 @@ class BaseDataset(torch.utils.data.Dataset):
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
image = None
elif image_info.latents_npz is not None: # when FineTuningDataset or cache_latents_to_disk=True
elif image_info.latents_cache_path is not None: # when FineTuningDataset or cache_latents_to_disk=True
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_cache_path, image_info.bucket_reso)
)
if flipped:
latents = flipped_latents
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1]
del flipped_latents
latents = torch.FloatTensor(latents)
if alpha_mask is not None:
alpha_mask = torch.FloatTensor(alpha_mask)
image = None
else:
@@ -1885,28 +1685,28 @@ class DreamBoothDataset(BaseDataset):
if strategy is not None:
logger.info("get image size from name of cache files")
# make image path to npz path mapping
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
npz_paths.sort(
# make image path to cache path mapping
cache_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
cache_paths.sort(
key=lambda item: item.rsplit("_", maxsplit=2)[0]
) # sort by name excluding resolution and cache_suffix
npz_path_index = 0
cache_path_index = 0
size_set_count = 0
for i, img_path in enumerate(tqdm(img_paths)):
l = len(os.path.splitext(img_path)[0]) # remove extension
found = False
while npz_path_index < len(npz_paths): # until found or end of npz_paths
while cache_path_index < len(cache_paths): # until found or end of npz_paths
# npz_paths are sorted, so if npz_path > img_path, img_path is not found
if npz_paths[npz_path_index][:l] > img_path[:l]:
if cache_paths[cache_path_index][:l] > img_path[:l]:
break
if npz_paths[npz_path_index][:l] == img_path[:l]: # found
if cache_paths[cache_path_index][:l] == img_path[:l]: # found
found = True
break
npz_path_index += 1 # next npz_path
cache_path_index += 1 # next npz_path
if found:
w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index])
w, h = strategy.get_image_size_from_disk_cache_path(img_path, cache_paths[cache_path_index])
else:
w, h = None, None
@@ -2139,8 +1939,8 @@ class FineTuningDataset(BaseDataset):
image_info.image_size = img_md.get("train_resolution")
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
# if cache exists, use them
image_info.latents_cache_path, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
self.register_image(image_info, subset)
@@ -2161,7 +1961,7 @@ class FineTuningDataset(BaseDataset):
for image_info in self.image_data.values():
subset = self.image_to_subset[image_info.image_key]
has_npz = image_info.latents_npz is not None
has_npz = image_info.latents_cache_path is not None
npz_any = npz_any or has_npz
if subset.flip_aug:
@@ -2233,7 +2033,7 @@ class FineTuningDataset(BaseDataset):
# clean up the npz info
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
image_info.latents_cache_path = image_info.latents_npz_flipped = None
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
@@ -2382,11 +2182,8 @@ class ControlNetDataset(BaseDataset):
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def new_cache_latents(self, model: Any, accelerator: Accelerator):
return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator)
def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool):
return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator, force_cache_precision)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
@@ -2485,33 +2282,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
def new_cache_latents(self, model: Any, accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(model, accelerator)
dataset.new_cache_latents(model, accelerator, force_cache_precision)
accelerator.wait_for_everyone()
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
def cache_text_encoder_outputs_sd3(
self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs_sd3(
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
)
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
@@ -2556,72 +2332,6 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding()
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # note that bucket_reso is WxH
if not os.path.exists(npz_path):
return False
try:
npz = np.load(npz_path)
if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver?
return False
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
# TODO update to use CachingStrategy
# def load_latents_from_disk(
# npz_path,
# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
# npz = np.load(npz_path)
# if "latents" not in npz:
# raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
# latents = npz["latents"]
# original_size = npz["original_size"].tolist()
# crop_ltrb = npz["crop_ltrb"].tolist()
# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
# kwargs = {}
# if flipped_latents_tensor is not None:
# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
# if alpha_mask is not None:
# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
# np.savez(
# npz_path,
# latents=latents_tensor.float().cpu().numpy(),
# original_size=np.array(original_size),
# crop_ltrb=np.array(crop_ltrb),
# **kwargs,
# )
def debug_dataset(train_dataset, show_input_ids=False):
logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
logger.info(
@@ -2865,19 +2575,19 @@ def trim_and_resize_if_required(
# for new_cache_latents
def load_images_and_masks_for_caching(
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
) -> Tuple[torch.Tensor, list[Optional[torch.Tensor]], list[Tuple[int, int]], list[Tuple[int, int, int, int]]]:
r"""
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs
image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1]
alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1]
alpha_masks: List[torch.Tensor] = [torch.Size([H, W]), ...], List of None if not use_alpha_mask
original_sizes: List[Tuple[int, int]] = [(W, H), ...]
crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
"""
images: List[torch.Tensor] = []
alpha_masks: List[np.ndarray] = []
alpha_masks: List[torch.Tensor] = []
original_sizes: List[Tuple[int, int]] = []
crop_ltrbs: List[Tuple[int, int, int, int]] = []
for info in image_infos:
@@ -2907,158 +2617,6 @@ def load_images_and_masks_for_caching(
return img_tensor, alpha_masks, original_sizes, crop_ltrbs
def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
optionally requires image_infos to have: image
if cache_to_disk is True, set info.latents_npz
flipped latents is also saved if flip_aug is True
if cache_to_disk is False, set info.latents
latents_flipped is also set if flip_aug is True
latents_original_size and latents_crop_ltrb are also set
"""
images = []
alpha_masks: List[np.ndarray] = []
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO add a check here: image metadata may be broken, so the bucket assigned from metadata may not match the actual image size
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
if use_alpha_mask:
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
else:
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
else:
alpha_mask = None
alpha_masks.append(alpha_mask)
image = image[:, :, :3] # remove alpha channel if exists
image = IMAGE_TRANSFORMS(image)
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
if flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
with torch.no_grad():
flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
else:
flipped_latents = [None] * len(latents)
for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
# save_latents_to_disk(
# info.latents_npz,
# latent,
# info.latents_original_size,
# info.latents_crop_ltrb,
# flipped_latent,
# alpha_mask,
# )
pass
else:
info.latents = latent
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
if not HIGH_VRAM:
clean_memory_on_device(vae.device)
def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device)
with torch.no_grad():
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
# move to cpu here, otherwise these values get overwritten
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280
b_pool2 = b_pool2.detach().to("cpu") # b,1280
for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
if cache_to_disk:
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2)
else:
info.text_encoder_outputs1 = hidden_state1
info.text_encoder_outputs2 = hidden_state2
info.text_encoder_pool2 = pool2
def cache_batch_text_encoder_outputs_sd3(
image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype
):
# make input_ids for each text encoder
l_tokens, g_tokens, t5_tokens = input_ids
clip_l, clip_g, t5xxl = text_encoders
with torch.no_grad():
b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens(
l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype
)
b_lg_out = b_lg_out.detach()
b_t5_out = b_t5_out.detach()
b_pool = b_pool.detach()
for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
# debug: NaN check
if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
if cache_to_disk:
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
else:
info.text_encoder_outputs1 = lg_out
info.text_encoder_outputs2 = t5_out
info.text_encoder_pool2 = pool
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
np.savez(
npz_path,
hidden_state1=hidden_state1.cpu().float().numpy(),
hidden_state2=hidden_state2.cpu().float().numpy(),
pool2=pool2.cpu().float().numpy(),
)
def load_text_encoder_outputs_from_disk(npz_path):
with np.load(npz_path) as f:
hidden_state1 = torch.from_numpy(f["hidden_state1"])
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
return hidden_state1, hidden_state2, pool2
# endregion
# region モジュール入れ替え部
@@ -4357,6 +3915,12 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument(
"--force_cache_precision",
action="store_true",
help="force cache precision to match the model precision. this option re-caches latents if the precision is lower than the model precision"
" / cacheの精度をモデルの精度に合わせる。このオプションを指定すると、精度がモデルの精度よりも低い場合にlatentを再キャッシュします",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
@@ -5913,7 +5477,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
names.append("unet")
names.append("text_encoder1")
names.append("text_encoder2")
names.append("text_encoder3") # SD3
names.append("text_encoder3") # SD3
append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
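To make the new --force_cache_precision behavior concrete, here is a minimal sketch with assumed example values, mirroring get_compatible_dtypes and the preferred_dtype handling above (not code from this commit): with the flag set, the model dtype becomes the preferred dtype, and an existing cache entry is only accepted if its dtype suffix is compatible; otherwise the latents are re-cached.

# Sketch of the dtype-preference check added in this commit (assumed example values).
import torch

def compatible_dtypes(dtype):
    # mirrors library.strategy_base.get_compatible_dtypes for a non-None dtype
    dtypes = [torch.float32]
    if dtype.itemsize == 1:  # fp8
        dtypes += [torch.bfloat16, torch.float16]
    dtypes.append(dtype)
    return dtypes

force_cache_precision = True
preferred_dtype = torch.bfloat16 if force_cache_precision else None  # model (VAE) dtype

existing_keys = {"latents_96x128_float16"}  # keys found in an existing cache file (assumed)

accepted = preferred_dtype is None or any(
    f"latents_96x128_{str(d).split('.')[-1]}" in existing_keys for d in compatible_dtypes(preferred_dtype)
)
print(accepted)  # False: a float16 cache does not satisfy a bfloat16 preference, so it is re-cached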

View File

@@ -189,6 +189,15 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
raise ValueError(f"Unsupported dtype: {s}")
def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
# get name of the dtype
dtype_name = str(dtype).split(".")[-1]
return dtype_name
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
@@ -264,8 +273,8 @@ class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.header, self.header_size = self._read_header()
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
def __enter__(self):
return self
@@ -276,6 +285,9 @@ class MemoryEfficientSafeOpen:
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def metadata(self) -> Dict[str, str]:
return self.header.get("__metadata__", {})
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
@@ -293,10 +305,9 @@ class MemoryEfficientSafeOpen:
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
with open(self.filename, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
header_size = struct.unpack("<Q", self.file.read(8))[0]
header_json = self.file.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
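A minimal usage sketch of MemoryEfficientSafeOpen with the metadata() accessor added here; the file name is an assumed example, and the class is used the same way the caching code above uses it.

# Sketch: read header metadata and tensors without locking the file (unlike safetensors' safe_open).
from library.utils import MemoryEfficientSafeOpen

with MemoryEfficientSafeOpen("img_1024x0768_flux.safetensors") as f:  # assumed example path
    print(f.metadata())            # "__metadata__" dict, e.g. {"architecture": "flux", ...}
    for key in f.keys():           # tensor keys, e.g. "latents_96x128_bfloat16"
        t = f.get_tensor(key)
        print(key, t.dtype, tuple(t.shape))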

View File

@@ -330,7 +330,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)

View File

@@ -272,7 +272,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -209,7 +209,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -181,7 +181,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -149,7 +149,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# cache latents with dataset
# TODO use DataLoader to speed up
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents to disk.")

View File

@@ -156,7 +156,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -418,7 +418,7 @@ class NetworkTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -378,7 +378,7 @@ class TextualInversionTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()