Compare commits

...

1 Commits

Author SHA1 Message Date
Kohya S
a437949d47 feat: Add support for Safetensors format in caching strategies (WIP)
- Introduced Safetensors output format for various caching strategies including Hunyuan, Lumina, SD, SDXL, and SD3.
- Updated methods to handle loading and saving of tensors in Safetensors format.
- Enhanced output validation to check for required tensors in both NPZ and Safetensors formats.
- Modified dataset argument parser to include `--cache_format` option for selecting between NPZ and Safetensors formats.
- Updated caching logic to accommodate partial loading and merging of existing Safetensors files.
2026-03-22 21:15:12 +09:00
11 changed files with 932 additions and 267 deletions

View File

@@ -155,6 +155,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
""" """
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz" ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors"
def __init__( def __init__(
self, self,
@@ -166,7 +167,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
if not self.cache_to_disk: if not self.cache_to_disk:
@@ -177,17 +179,34 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
try: try:
npz = np.load(npz_path) if npz_path.endswith(".safetensors"):
if "prompt_embeds" not in npz: from library.safetensors_utils import MemoryEfficientSafeOpen
return False from library.strategy_base import _find_tensor_by_prefix
if "attn_mask" not in npz:
return False with MemoryEfficientSafeOpen(npz_path) as f:
if "t5_input_ids" not in npz: keys = f.keys()
return False if not _find_tensor_by_prefix(keys, "prompt_embeds"):
if "t5_attn_mask" not in npz: return False
return False if "attn_mask" not in keys:
if "caption_dropout_rate" not in npz: return False
return False if "t5_input_ids" not in keys:
return False
if "t5_attn_mask" not in keys:
return False
if "caption_dropout_rate" not in keys:
return False
else:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -195,6 +214,19 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy()
attn_mask = f.get_tensor("attn_mask").numpy()
t5_input_ids = f.get_tensor("t5_input_ids").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy()
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
data = np.load(npz_path) data = np.load(npz_path)
prompt_embeds = data["prompt_embeds"] prompt_embeds = data["prompt_embeds"]
attn_mask = data["attn_mask"] attn_mask = data["attn_mask"]
@@ -219,32 +251,75 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, tokens_and_masks tokenize_strategy, models, tokens_and_masks
) )
# Convert to numpy for caching if self.cache_format == "safetensors":
if prompt_embeds.dtype == torch.bfloat16: self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos)
prompt_embeds = prompt_embeds.float() else:
prompt_embeds = prompt_embeds.cpu().numpy() # Convert to numpy for caching
attn_mask = attn_mask.cpu().numpy() if prompt_embeds.dtype == torch.bfloat16:
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32) prompt_embeds = prompt_embeds.float()
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32) prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
prompt_embeds = prompt_embeds.cpu()
attn_mask = attn_mask.cpu()
t5_input_ids = t5_input_ids.cpu().to(torch.int32)
t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)
for i, info in enumerate(infos): for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk: if self.cache_to_disk:
np.savez( tensors = {}
info.text_encoder_outputs_npz, if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
prompt_embeds=prompt_embeds_i, with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
attn_mask=attn_mask_i, for key in f.keys():
t5_input_ids=t5_input_ids_i, tensors[key] = f.get_tensor(key)
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate, pe = prompt_embeds[i]
) tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
tensors["attn_mask"] = attn_mask[i]
tensors["t5_input_ids"] = t5_input_ids[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["caption_dropout_rate"] = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
metadata = {
"architecture": "anima",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else: else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate) caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
info.text_encoder_outputs = (
prompt_embeds[i].numpy(),
attn_mask[i].numpy(),
t5_input_ids[i].numpy(),
t5_attn_mask[i].numpy(),
caption_dropout_rate,
)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy): class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -255,16 +330,20 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
""" """
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz" ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
ANIMA_LATENTS_ST_SUFFIX = "_anima.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return self.ANIMA_LATENTS_NPZ_SUFFIX return self.ANIMA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "anima"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -2,7 +2,7 @@
import os import os
import re import re
from typing import Any, List, Optional, Tuple, Union, Callable from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import numpy as np import numpy as np
import torch import torch
@@ -19,6 +19,48 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LATENTS_CACHE_FORMAT_VERSION = "1.0.1"
TE_OUTPUTS_CACHE_FORMAT_VERSION = "1.0.1"
# global cache format setting: "npz" or "safetensors"
_cache_format: str = "npz"
def set_cache_format(cache_format: str) -> None:
global _cache_format
_cache_format = cache_format
def get_cache_format() -> str:
return _cache_format
_TORCH_DTYPE_TO_STR = {
torch.float64: "float64",
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
torch.int16: "int16",
torch.int8: "int8",
torch.uint8: "uint8",
torch.bool: "bool",
}
_FLOAT_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16}
def _dtype_to_str(dtype: torch.dtype) -> str:
return _TORCH_DTYPE_TO_STR.get(dtype, str(dtype).replace("torch.", ""))
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
"""Find a tensor key that starts with the given prefix. Returns the first match or None."""
for key in tensors_keys:
if key.startswith(prefix) or key == prefix:
return key
return None
class TokenizeStrategy: class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class _strategy = None # strategy instance: actual strategy class
@@ -362,6 +404,10 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self): def is_weighted(self):
return self._is_weighted return self._is_weighted
@property
def cache_format(self) -> str:
return get_cache_format()
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError raise NotImplementedError
@@ -407,6 +453,10 @@ class LatentsCachingStrategy:
def batch_size(self): def batch_size(self):
return self._batch_size return self._batch_size
@property
def cache_format(self) -> str:
return get_cache_format()
@property @property
def cache_suffix(self): def cache_suffix(self):
raise NotImplementedError raise NotImplementedError
@@ -439,7 +489,7 @@ class LatentsCachingStrategy:
Args: Args:
latents_stride: stride of latents latents_stride: stride of latents
bucket_reso: resolution of the bucket bucket_reso: resolution of the bucket
npz_path: path to the npz file npz_path: path to the npz/safetensors file
flip_aug: whether to flip images flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents multi_resolution: whether to use multi-resolution latents
@@ -454,6 +504,11 @@ class LatentsCachingStrategy:
if self.skip_disk_cache_validity_check: if self.skip_disk_cache_validity_check:
return True return True
if npz_path.endswith(".safetensors"):
return self._is_disk_cached_latents_expected_safetensors(
latents_stride, bucket_reso, npz_path, flip_aug, apply_alpha_mask, multi_resolution
)
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW # e.g. "_32x64", HxW
@@ -476,6 +531,40 @@ class LatentsCachingStrategy:
return True return True
def _is_disk_cached_latents_expected_safetensors(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
st_path: str,
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
) -> bool:
from library.safetensors_utils import MemoryEfficientSafeOpen
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # (H, W)
reso_tag = f"1x{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "1x"
try:
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_prefix = f"latents_{reso_tag}"
if not any(k.startswith(latents_prefix) for k in keys):
return False
if flip_aug:
flipped_prefix = f"latents_flipped_{reso_tag}"
if not any(k.startswith(flipped_prefix) for k in keys):
return False
if apply_alpha_mask:
mask_prefix = f"alpha_mask_{reso_tag}"
if not any(k.startswith(mask_prefix) for k in keys):
return False
except Exception as e:
logger.error(f"Error loading file: {st_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo # TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents( def _default_cache_batch_latents(
self, self,
@@ -571,7 +660,7 @@ class LatentsCachingStrategy:
""" """
Args: Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents. latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file. npz_path (str): Path to the npz/safetensors file.
bucket_reso (Tuple[int, int]): The resolution of the bucket. bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns: Returns:
@@ -583,6 +672,9 @@ class LatentsCachingStrategy:
Optional[np.ndarray] Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
""" """
if npz_path.endswith(".safetensors"):
return self._load_latents_from_disk_safetensors(latents_stride, npz_path, bucket_reso)
if latents_stride is None: if latents_stride is None:
key_reso_suffix = "" key_reso_suffix = ""
else: else:
@@ -609,6 +701,39 @@ class LatentsCachingStrategy:
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def _load_latents_from_disk_safetensors(
self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
from library.safetensors_utils import MemoryEfficientSafeOpen
if latents_stride is None:
reso_tag = "1x"
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_key = _find_tensor_by_prefix(keys, f"latents_{reso_tag}")
if latents_key is None:
raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
latents = f.get_tensor(latents_key).numpy()
original_size_key = _find_tensor_by_prefix(keys, f"original_size_{reso_tag}")
original_size = f.get_tensor(original_size_key).numpy().tolist() if original_size_key else [0, 0]
crop_ltrb_key = _find_tensor_by_prefix(keys, f"crop_ltrb_{reso_tag}")
crop_ltrb = f.get_tensor(crop_ltrb_key).numpy().tolist() if crop_ltrb_key else [0, 0, 0, 0]
flipped_key = _find_tensor_by_prefix(keys, f"latents_flipped_{reso_tag}")
flipped_latents = f.get_tensor(flipped_key).numpy() if flipped_key else None
alpha_mask_key = _find_tensor_by_prefix(keys, f"alpha_mask_{reso_tag}")
alpha_mask = f.get_tensor(alpha_mask_key).numpy() if alpha_mask_key else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk( def save_latents_to_disk(
self, self,
npz_path, npz_path,
@@ -621,17 +746,23 @@ class LatentsCachingStrategy:
): ):
""" """
Args: Args:
npz_path (str): Path to the npz file. npz_path (str): Path to the npz/safetensors file.
latents_tensor (torch.Tensor): Latent tensor latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix key_reso_suffix (str): Key resolution suffix (e.g. "_32x64" for multi-resolution npz)
Returns: Returns:
None None
""" """
if npz_path.endswith(".safetensors"):
self._save_latents_to_disk_safetensors(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor, alpha_mask, key_reso_suffix
)
return
kwargs = {} kwargs = {}
if os.path.exists(npz_path): if os.path.exists(npz_path):
@@ -640,7 +771,7 @@ class LatentsCachingStrategy:
for key in npz.files: for key in npz.files:
kwargs[key] = npz[key] kwargs[key] = npz[key]
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16. # float() is needed because npz doesn't support bfloat16
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size) kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
@@ -649,3 +780,59 @@ class LatentsCachingStrategy:
if alpha_mask is not None: if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs) np.savez(npz_path, **kwargs)
def _save_latents_to_disk_safetensors(
self,
st_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
latents_tensor = latents_tensor.cpu()
latents_size = latents_tensor.shape[-2:] # H, W
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
dtype_str = _dtype_to_str(latents_tensor.dtype)
# NaN check and zero replacement
if torch.isnan(latents_tensor).any():
latents_tensor = torch.nan_to_num(latents_tensor, nan=0.0)
tensors: Dict[str, torch.Tensor] = {}
# load existing file and merge (for multi-resolution)
if os.path.exists(st_path):
with MemoryEfficientSafeOpen(st_path) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
tensors[f"latents_{reso_tag}_{dtype_str}"] = latents_tensor
tensors[f"original_size_{reso_tag}_int32"] = torch.tensor(original_size, dtype=torch.int32)
tensors[f"crop_ltrb_{reso_tag}_int32"] = torch.tensor(crop_ltrb, dtype=torch.int32)
if flipped_latents_tensor is not None:
flipped_latents_tensor = flipped_latents_tensor.cpu()
if torch.isnan(flipped_latents_tensor).any():
flipped_latents_tensor = torch.nan_to_num(flipped_latents_tensor, nan=0.0)
tensors[f"latents_flipped_{reso_tag}_{dtype_str}"] = flipped_latents_tensor
if alpha_mask is not None:
alpha_mask_tensor = alpha_mask.cpu() if isinstance(alpha_mask, torch.Tensor) else torch.tensor(alpha_mask)
tensors[f"alpha_mask_{reso_tag}"] = alpha_mask_tensor
metadata = {
"architecture": self._get_architecture_name(),
"width": str(latents_size[1]),
"height": str(latents_size[0]),
"format_version": LATENTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, st_path, metadata=metadata)
def _get_architecture_name(self) -> str:
"""Override in subclasses to return the architecture name for safetensors metadata."""
return "unknown"

View File

@@ -87,6 +87,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_flux_te.safetensors"
def __init__( def __init__(
self, self,
@@ -102,7 +103,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = False self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX suffix = self.FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str): def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk: if not self.cache_to_disk:
@@ -113,20 +115,40 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
try: try:
npz = np.load(npz_path) if npz_path.endswith(".safetensors"):
if "l_pooled" not in npz: from library.safetensors_utils import MemoryEfficientSafeOpen
return False from library.strategy_base import _find_tensor_by_prefix
if "t5_out" not in npz:
return False with MemoryEfficientSafeOpen(npz_path) as f:
if "txt_ids" not in npz: keys = f.keys()
return False if not _find_tensor_by_prefix(keys, "l_pooled"):
if "t5_attn_mask" not in npz: return False
return False if not _find_tensor_by_prefix(keys, "t5_out"):
if "apply_t5_attn_mask" not in npz: return False
return False if not _find_tensor_by_prefix(keys, "txt_ids"):
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] return False
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: if "t5_attn_mask" not in keys:
return False return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -134,6 +156,18 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
l_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "l_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
txt_ids = f.get_tensor(_find_tensor_by_prefix(keys, "txt_ids")).numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
data = np.load(npz_path) data = np.load(npz_path)
l_pooled = data["l_pooled"] l_pooled = data["l_pooled"]
t5_out = data["t5_out"] t5_out = data["t5_out"]
@@ -161,56 +195,100 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16: t5_attn_mask_tokens = tokens_and_masks[2]
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu().numpy() if self.cache_format == "safetensors":
t5_out = t5_out.cpu().numpy() self._cache_batch_outputs_safetensors(l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos)
txt_ids = txt_ids.cpu().numpy() else:
t5_attn_mask = tokens_and_masks[2].cpu().numpy() if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(self, l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
for i, info in enumerate(infos): for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk: if self.cache_to_disk:
np.savez( tensors = {}
info.text_encoder_outputs_npz, if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
l_pooled=l_pooled_i, with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
t5_out=t5_out_i, for key in f.keys():
txt_ids=txt_ids_i, tensors[key] = f.get_tensor(key)
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i, lp = l_pooled[i]
) to = t5_out[i]
ti = txt_ids[i]
tensors[f"l_pooled_{_dtype_to_str(lp.dtype)}"] = lp
tensors[f"t5_out_{_dtype_to_str(to.dtype)}"] = to
tensors[f"txt_ids_{_dtype_to_str(ti.dtype)}"] = ti
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "flux",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else: else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary info.text_encoder_outputs = (l_pooled[i].numpy(), t5_out[i].numpy(), txt_ids[i].numpy(), t5_attn_mask[i].numpy())
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
class FluxLatentsCachingStrategy(LatentsCachingStrategy): class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
FLUX_LATENTS_ST_SUFFIX = "_flux.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX return self.FLUX_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return ( return (
os.path.splitext(absolute_path)[0] os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}" + f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + self.cache_suffix
) )
def _get_architecture_name(self) -> str:
return "flux"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -81,16 +81,17 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz" HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_hi_te.safetensors"
def __init__( def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False,
) -> None: ) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return ( return (
os.path.splitext(image_abs_path)[0] os.path.splitext(image_abs_path)[0] + suffix
+ HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
) )
def is_disk_cached_outputs_expected(self, npz_path: str): def is_disk_cached_outputs_expected(self, npz_path: str):
@@ -102,17 +103,34 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True return True
try: try:
npz = np.load(npz_path) if npz_path.endswith(".safetensors"):
if "vlm_embed" not in npz: from library.safetensors_utils import MemoryEfficientSafeOpen
return False from library.strategy_base import _find_tensor_by_prefix
if "vlm_mask" not in npz:
return False with MemoryEfficientSafeOpen(npz_path) as f:
if "byt5_embed" not in npz: keys = f.keys()
return False if not _find_tensor_by_prefix(keys, "vlm_embed"):
if "byt5_mask" not in npz: return False
return False if "vlm_mask" not in keys:
if "ocr_mask" not in npz: return False
return False if not _find_tensor_by_prefix(keys, "byt5_embed"):
return False
if "byt5_mask" not in keys:
return False
if "ocr_mask" not in keys:
return False
else:
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
if "vlm_mask" not in npz:
return False
if "byt5_embed" not in npz:
return False
if "byt5_mask" not in npz:
return False
if "ocr_mask" not in npz:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -120,6 +138,19 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
vlm_embed = f.get_tensor(_find_tensor_by_prefix(keys, "vlm_embed")).numpy()
vlm_mask = f.get_tensor("vlm_mask").numpy()
byt5_embed = f.get_tensor(_find_tensor_by_prefix(keys, "byt5_embed")).numpy()
byt5_mask = f.get_tensor("byt5_mask").numpy()
ocr_mask = f.get_tensor("ocr_mask").numpy()
return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
data = np.load(npz_path) data = np.load(npz_path)
vln_embed = data["vlm_embed"] vln_embed = data["vlm_embed"]
vlm_mask = data["vlm_mask"] vlm_mask = data["vlm_mask"]
@@ -140,54 +171,102 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
tokenize_strategy, models, tokens_and_masks tokenize_strategy, models, tokens_and_masks
) )
if vlm_embed.dtype == torch.bfloat16: if self.cache_format == "safetensors":
vlm_embed = vlm_embed.float() self._cache_batch_outputs_safetensors(vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos)
if byt5_embed.dtype == torch.bfloat16: else:
byt5_embed = byt5_embed.float() if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
byt5_embed = byt5_embed.float()
vlm_embed = vlm_embed.cpu().numpy() vlm_embed = vlm_embed.cpu().numpy()
vlm_mask = vlm_mask.cpu().numpy() vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy() byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy() byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = ocr_mask.cpu().numpy() ocr_mask = ocr_mask.cpu().numpy()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
vlm_embed=vlm_embed_i,
vlm_mask=vlm_mask_i,
byt5_embed=byt5_embed_i,
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i,
)
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
def _cache_batch_outputs_safetensors(self, vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
vlm_embed = vlm_embed.cpu()
vlm_mask = vlm_mask.cpu()
byt5_embed = byt5_embed.cpu()
byt5_mask = byt5_mask.cpu()
ocr_mask = ocr_mask.cpu()
for i, info in enumerate(infos): for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk: if self.cache_to_disk:
np.savez( tensors = {}
info.text_encoder_outputs_npz, if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
vlm_embed=vlm_embed_i, with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
vlm_mask=vlm_mask_i, for key in f.keys():
byt5_embed=byt5_embed_i, tensors[key] = f.get_tensor(key)
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i, ve = vlm_embed[i]
) be = byt5_embed[i]
tensors[f"vlm_embed_{_dtype_to_str(ve.dtype)}"] = ve
tensors["vlm_mask"] = vlm_mask[i]
tensors[f"byt5_embed_{_dtype_to_str(be.dtype)}"] = be
tensors["byt5_mask"] = byt5_mask[i]
tensors["ocr_mask"] = ocr_mask[i]
metadata = {
"architecture": "hunyuan_image",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else: else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i) info.text_encoder_outputs = (
vlm_embed[i].numpy(),
vlm_mask[i].numpy(),
byt5_embed[i].numpy(),
byt5_mask[i].numpy(),
ocr_mask[i].numpy(),
)
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy): class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz" HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"
HUNYUAN_IMAGE_LATENTS_ST_SUFFIX = "_hi.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX return self.HUNYUAN_IMAGE_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return ( return (
os.path.splitext(absolute_path)[0] os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}" + f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX + self.cache_suffix
) )
def _get_architecture_name(self) -> str:
return "hunyuan_image"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -146,6 +146,7 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz" LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors"
def __init__( def __init__(
self, self,
@@ -162,19 +163,10 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
) )
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
return ( suffix = self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
os.path.splitext(image_abs_path)[0] return os.path.splitext(image_abs_path)[0] + suffix
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
"""
Args:
npz_path (str): Path to the npz file.
Returns:
bool: True if the npz file is expected to be cached.
"""
if not self.cache_to_disk: if not self.cache_to_disk:
return False return False
if not os.path.exists(npz_path): if not os.path.exists(npz_path):
@@ -183,13 +175,26 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return True return True
try: try:
npz = np.load(npz_path) if npz_path.endswith(".safetensors"):
if "hidden_state" not in npz: from library.safetensors_utils import MemoryEfficientSafeOpen
return False from library.strategy_base import _find_tensor_by_prefix
if "attention_mask" not in npz:
return False with MemoryEfficientSafeOpen(npz_path) as f:
if "input_ids" not in npz: keys = f.keys()
return False if not _find_tensor_by_prefix(keys, "hidden_state"):
return False
if "attention_mask" not in keys:
return False
if "input_ids" not in keys:
return False
else:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -198,11 +203,22 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
""" """
Load outputs from a npz file Load outputs from a npz/safetensors file
Returns: Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask List[np.ndarray]: hidden_state, input_ids, attention_mask
""" """
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")).numpy()
attention_mask = f.get_tensor("attention_mask").numpy()
input_ids = f.get_tensor("input_ids").numpy()
return [hidden_state, input_ids, attention_mask]
data = np.load(npz_path) data = np.load(npz_path)
hidden_state = data["hidden_state"] hidden_state = data["hidden_state"]
attention_mask = data["attention_mask"] attention_mask = data["attention_mask"]
@@ -217,16 +233,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
text_encoding_strategy: TextEncodingStrategy, text_encoding_strategy: TextEncodingStrategy,
batch: List[train_util.ImageInfo], batch: List[train_util.ImageInfo],
) -> None: ) -> None:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of ImageInfo
Returns:
None
"""
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
@@ -252,37 +258,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
) )
) )
if hidden_state.dtype != torch.float32: if self.cache_format == "safetensors":
hidden_state = hidden_state.float() self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch)
else:
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy() hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy() # (B, S) attention_mask = attention_masks.cpu().numpy()
input_ids = input_ids.cpu().numpy() # (B, S) input_ids_np = input_ids.cpu().numpy()
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids_np[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
)
else:
info.text_encoder_outputs = [
hidden_state_i,
input_ids_i,
attention_mask_i,
]
def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, attention_masks, batch):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state = hidden_state.cpu()
input_ids = input_ids.cpu()
attention_mask = attention_masks.cpu()
for i, info in enumerate(batch): for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
if self.cache_to_disk: if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}" assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez( tensors = {}
info.text_encoder_outputs_npz, if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
hidden_state=hidden_state_i, with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
attention_mask=attention_mask_i, for key in f.keys():
input_ids=input_ids_i, tensors[key] = f.get_tensor(key)
)
hs = hidden_state[i]
tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs
tensors["attention_mask"] = attention_mask[i]
tensors["input_ids"] = input_ids[i]
metadata = {
"architecture": "lumina",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else: else:
info.text_encoder_outputs = [ info.text_encoder_outputs = [
hidden_state_i, hidden_state[i].numpy(),
input_ids_i, input_ids[i].numpy(),
attention_mask_i, attention_mask[i].numpy(),
] ]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy): class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz" LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors"
def __init__( def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
@@ -291,7 +335,7 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path( def get_latents_npz_path(
self, absolute_path: str, image_size: Tuple[int, int] self, absolute_path: str, image_size: Tuple[int, int]
@@ -299,9 +343,12 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
return ( return (
os.path.splitext(absolute_path)[0] os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}" + f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + self.cache_suffix
) )
def _get_architecture_name(self) -> str:
return "lumina"
def is_disk_cached_latents_expected( def is_disk_cached_latents_expected(
self, self,
bucket_reso: Tuple[int, int], bucket_reso: Tuple[int, int],

View File

@@ -138,24 +138,32 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz" SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
SD_LATENTS_ST_SUFFIX = "_sd.safetensors"
SDXL_LATENTS_ST_SUFFIX = "_sdxl.safetensors"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(
self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return self.suffix if self.cache_format == "safetensors":
return self.SD_LATENTS_ST_SUFFIX if self.sd else self.SDXL_LATENTS_ST_SUFFIX
else:
return self.SD_LATENTS_NPZ_SUFFIX if self.sd else self.SDXL_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz if self.cache_format != "safetensors":
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX # support old .npz
if os.path.exists(old_npz_file): old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
return old_npz_file if os.path.exists(old_npz_file):
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "sd" if self.sd else "sdxl"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -255,6 +255,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_sd3_te.safetensors"
def __init__( def __init__(
self, self,
@@ -270,7 +271,8 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.apply_t5_attn_mask = apply_t5_attn_mask self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX suffix = self.SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str): def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk: if not self.cache_to_disk:
@@ -281,27 +283,54 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
try: try:
npz = np.load(npz_path) if npz_path.endswith(".safetensors"):
if "lg_out" not in npz: from library.safetensors_utils import MemoryEfficientSafeOpen
return False from library.strategy_base import _find_tensor_by_prefix
if "lg_pooled" not in npz:
return False with MemoryEfficientSafeOpen(npz_path) as f:
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used keys = f.keys()
return False if not _find_tensor_by_prefix(keys, "lg_out"):
if "apply_lg_attn_mask" not in npz: return False
return False if not _find_tensor_by_prefix(keys, "lg_pooled"):
if "t5_out" not in npz: return False
return False if "clip_l_attn_mask" not in keys or "clip_g_attn_mask" not in keys:
if "t5_attn_mask" not in npz: return False
return False if not _find_tensor_by_prefix(keys, "t5_out"):
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] return False
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: if "t5_attn_mask" not in keys:
return False return False
if "apply_t5_attn_mask" not in npz: if "apply_lg_attn_mask" not in keys:
return False return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] apply_lg = f.get_tensor("apply_lg_attn_mask").item()
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: if bool(apply_lg) != self.apply_lg_attn_mask:
return False return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -309,6 +338,20 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
lg_out = f.get_tensor(_find_tensor_by_prefix(keys, "lg_out")).numpy()
lg_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "lg_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
l_attn_mask = f.get_tensor("clip_l_attn_mask").numpy()
g_attn_mask = f.get_tensor("clip_g_attn_mask").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
data = np.load(npz_path) data = np.load(npz_path)
lg_out = data["lg_out"] lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"] lg_pooled = data["lg_pooled"]
@@ -339,65 +382,127 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
enable_dropout=False, enable_dropout=False,
) )
if lg_out.dtype == torch.bfloat16: l_attn_mask_tokens = tokens_and_masks[3]
lg_out = lg_out.float() g_attn_mask_tokens = tokens_and_masks[4]
if lg_pooled.dtype == torch.bfloat16: t5_attn_mask_tokens = tokens_and_masks[5]
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy() if self.cache_format == "safetensors":
lg_pooled = lg_pooled.cpu().numpy() self._cache_batch_outputs_safetensors(
t5_out = t5_out.cpu().numpy() lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
)
else:
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
l_attn_mask = tokens_and_masks[3].cpu().numpy() lg_out = lg_out.cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy() lg_pooled = lg_pooled.cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy() t5_out = t5_out.cpu().numpy()
l_attn_mask = l_attn_mask_tokens.cpu().numpy()
g_attn_mask = g_attn_mask_tokens.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(
self, lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
lg_out = lg_out.cpu()
t5_out = t5_out.cpu()
lg_pooled = lg_pooled.cpu()
l_attn_mask = l_attn_mask_tokens.cpu()
g_attn_mask = g_attn_mask_tokens.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
for i, info in enumerate(infos): for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk: if self.cache_to_disk:
np.savez( tensors = {}
info.text_encoder_outputs_npz, if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
lg_out=lg_out_i, with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
lg_pooled=lg_pooled_i, for key in f.keys():
t5_out=t5_out_i, tensors[key] = f.get_tensor(key)
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i, lg_out_i = lg_out[i]
t5_attn_mask=t5_attn_mask_i, t5_out_i = t5_out[i]
apply_lg_attn_mask=apply_lg_attn_mask, lg_pooled_i = lg_pooled[i]
apply_t5_attn_mask=apply_t5_attn_mask, tensors[f"lg_out_{_dtype_to_str(lg_out_i.dtype)}"] = lg_out_i
) tensors[f"t5_out_{_dtype_to_str(t5_out_i.dtype)}"] = t5_out_i
tensors[f"lg_pooled_{_dtype_to_str(lg_pooled_i.dtype)}"] = lg_pooled_i
tensors["clip_l_attn_mask"] = l_attn_mask[i]
tensors["clip_g_attn_mask"] = g_attn_mask[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_lg_attn_mask"] = torch.tensor(self.apply_lg_attn_mask, dtype=torch.bool)
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "sd3",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else: else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary info.text_encoder_outputs = (
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) lg_out[i].numpy(),
t5_out[i].numpy(),
lg_pooled[i].numpy(),
l_attn_mask[i].numpy(),
g_attn_mask[i].numpy(),
t5_attn_mask[i].numpy(),
)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy): class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
SD3_LATENTS_ST_SUFFIX = "_sd3.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property @property
def cache_suffix(self) -> str: def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX return self.SD3_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return ( return (
os.path.splitext(absolute_path)[0] os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}" + f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + self.cache_suffix
) )
def _get_architecture_name(self) -> str:
return "sd3"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -221,6 +221,7 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_te_outputs.safetensors"
def __init__( def __init__(
self, self,
@@ -233,7 +234,8 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def get_outputs_npz_path(self, image_abs_path: str) -> str: def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX suffix = self.SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str): def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk: if not self.cache_to_disk:
@@ -244,9 +246,22 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
try: try:
npz = np.load(npz_path) if npz_path.endswith(".safetensors"):
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: from library.safetensors_utils import MemoryEfficientSafeOpen
return False from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state1"):
return False
if not _find_tensor_by_prefix(keys, "hidden_state2"):
return False
if not _find_tensor_by_prefix(keys, "pool2"):
return False
else:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -254,6 +269,17 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
    """Load cached SDXL text encoder outputs from disk.

    Supports both cache formats: Safetensors (keys carry a dtype suffix,
    located via prefix match) and legacy NPZ (plain keys).

    Returns:
        [hidden_state1, hidden_state2, pool2] as numpy arrays.
    """
    if npz_path.endswith(".safetensors"):
        # Local imports keep the safetensors dependency off the NPZ-only path.
        from library.safetensors_utils import MemoryEfficientSafeOpen
        from library.strategy_base import _find_tensor_by_prefix

        with MemoryEfficientSafeOpen(npz_path) as f:
            keys = f.keys()
            # Keys embed the stored dtype (e.g. "hidden_state1_bf16"), so match by prefix.
            hidden_state1 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state1")).numpy()
            hidden_state2 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state2")).numpy()
            pool2 = f.get_tensor(_find_tensor_by_prefix(keys, "pool2")).numpy()
        return [hidden_state1, hidden_state2, pool2]

    # Legacy NPZ format: arrays stored under plain keys.
    # NOTE(review): the tail of the original NPZ branch is reconstructed from the
    # cache layout ("pool2" key, same return shape as the safetensors branch) — confirm.
    data = np.load(npz_path)
    hidden_state1 = data["hidden_state1"]
    hidden_state2 = data["hidden_state2"]
    pool2 = data["pool2"]
    return [hidden_state1, hidden_state2, pool2]
@@ -279,28 +305,68 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2] tokenize_strategy, models, [tokens1, tokens2]
) )
if hidden_state1.dtype == torch.bfloat16: if self.cache_format == "safetensors":
hidden_state1 = hidden_state1.float() self._cache_batch_outputs_safetensors(hidden_state1, hidden_state2, pool2, infos)
if hidden_state2.dtype == torch.bfloat16: else:
hidden_state2 = hidden_state2.float() if hidden_state1.dtype == torch.bfloat16:
if pool2.dtype == torch.bfloat16: hidden_state1 = hidden_state1.float()
pool2 = pool2.float() if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy() hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy() hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy() pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
def _cache_batch_outputs_safetensors(self, hidden_state1, hidden_state2, pool2, infos):
    """Save a batch of SDXL text encoder outputs in Safetensors format.

    Unlike the NPZ path, tensors are kept in their native dtype (e.g. bf16);
    the dtype is encoded into the key name (e.g. "hidden_state1_bf16") so
    loaders can recover it by prefix match.
    """
    # Local imports keep safetensors support optional for NPZ-only runs.
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
    from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

    # Move to CPU for serialization; dtype is intentionally preserved.
    hidden_state1 = hidden_state1.cpu()
    hidden_state2 = hidden_state2.cpu()
    pool2 = pool2.cpu()

    for i, info in enumerate(infos):
        if self.cache_to_disk:
            tensors = {}
            # Merge the existing file when caching partially, so tensors
            # written by earlier passes are not lost.
            if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
                with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                    for key in f.keys():
                        tensors[key] = f.get_tensor(key)
            hs1 = hidden_state1[i]
            hs2 = hidden_state2[i]
            p2 = pool2[i]
            tensors[f"hidden_state1_{_dtype_to_str(hs1.dtype)}"] = hs1
            tensors[f"hidden_state2_{_dtype_to_str(hs2.dtype)}"] = hs2
            tensors[f"pool2_{_dtype_to_str(p2.dtype)}"] = p2
            metadata = {
                "architecture": "sdxl",
                "caption1": info.caption,
                "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
            }
            mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
        else:
            # In-memory cache: downstream consumers expect numpy arrays.
            # NOTE(review): torch bf16 tensors do not support .numpy(); verify the
            # in-memory path is never reached with bf16 outputs — TODO confirm.
            info.text_encoder_outputs = [
                hidden_state1[i].numpy(),
                hidden_state2[i].numpy(),
                pool2[i].numpy(),
            ]

View File

@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
return all( return all(
[ [
not ( not (
subset.caption_dropout_rate > 0 and not cache_supports_dropout subset.caption_dropout_rate > 0
and not cache_supports_dropout
or subset.shuffle_caption or subset.shuffle_caption
or subset.token_warmup_step > 0 or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0 or subset.caption_tag_dropout_rate > 0
@@ -4471,7 +4472,10 @@ def verify_training_args(args: argparse.Namespace):
Verify training arguments. Also reflect highvram option to global variable Verify training arguments. Also reflect highvram option to global variable
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
""" """
from library.strategy_base import set_cache_format
enable_high_vram(args) enable_high_vram(args)
set_cache_format(args.cache_format)
if args.v2 and args.clip_skip is not None: if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -4637,6 +4641,14 @@ def add_dataset_arguments(
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", " / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
) )
# On-disk format for latent / text encoder output caches; "safetensors" keeps
# tensors in their native dtype (e.g. bf16), unlike the legacy "npz" default.
parser.add_argument(
    "--cache_format",
    type=str,
    default="npz",
    choices=["npz", "safetensors"],
    help="format for latent and text encoder output caches (default: npz). safetensors saves in native dtype (e.g. bf16) for smaller files and faster I/O"
    " / latentおよびtext encoder出力キャッシュの保存形式デフォルト: npz。safetensorsはネイティブdtype例: bf16で保存し、ファイルサイズ削減と高速化が可能",
)
parser.add_argument( parser.add_argument(
"--enable_bucket", "--enable_bucket",
action="store_true", action="store_true",

View File

@@ -69,6 +69,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
strategy_base.set_cache_format(args.cache_format)
if is_sd or is_sdxl: if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
else: else:

View File

@@ -156,6 +156,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
text_encoder.eval() text_encoder.eval()
# build text encoder outputs caching strategy # build text encoder outputs caching strategy
strategy_base.set_cache_format(args.cache_format)
if is_sdxl: if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions