diff --git a/library/strategy_anima.py b/library/strategy_anima.py index d89df5b9..cbf21029 100644 --- a/library/strategy_anima.py +++ b/library/strategy_anima.py @@ -155,6 +155,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): """ ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz" + ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors" def __init__( self, @@ -166,7 +167,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) def get_outputs_npz_path(self, image_abs_path: str) -> str: - return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + return os.path.splitext(image_abs_path)[0] + suffix def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: if not self.cache_to_disk: @@ -177,17 +179,34 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True try: - npz = np.load(npz_path) - if "prompt_embeds" not in npz: - return False - if "attn_mask" not in npz: - return False - if "t5_input_ids" not in npz: - return False - if "t5_attn_mask" not in npz: - return False - if "caption_dropout_rate" not in npz: - return False + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + if not _find_tensor_by_prefix(keys, "prompt_embeds"): + return False + if "attn_mask" not in keys: + return False + if "t5_input_ids" not in keys: + return False + if "t5_attn_mask" not in keys: + return False + if "caption_dropout_rate" not in keys: + return False + else: + npz = np.load(npz_path) + if "prompt_embeds" not in npz: + return 
False + if "attn_mask" not in npz: + return False + if "t5_input_ids" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + if "caption_dropout_rate" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -195,6 +214,19 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy() + attn_mask = f.get_tensor("attn_mask").numpy() + t5_input_ids = f.get_tensor("t5_input_ids").numpy() + t5_attn_mask = f.get_tensor("t5_attn_mask").numpy() + caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy() + return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate] + data = np.load(npz_path) prompt_embeds = data["prompt_embeds"] attn_mask = data["attn_mask"] @@ -219,32 +251,75 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokenize_strategy, models, tokens_and_masks ) - # Convert to numpy for caching - if prompt_embeds.dtype == torch.bfloat16: - prompt_embeds = prompt_embeds.float() - prompt_embeds = prompt_embeds.cpu().numpy() - attn_mask = attn_mask.cpu().numpy() - t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32) - t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32) + if self.cache_format == "safetensors": + self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos) + else: + # Convert to numpy for caching + if prompt_embeds.dtype == torch.bfloat16: + prompt_embeds = prompt_embeds.float() + prompt_embeds = prompt_embeds.cpu().numpy() + attn_mask = 
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
    """Cache one batch of Anima text-encoder outputs for the safetensors cache format.

    When ``self.cache_to_disk`` is false, every sample's outputs are stored on
    ``info.text_encoder_outputs`` as numpy arrays (plus the caption dropout
    rate as a float32 tensor). Otherwise each sample is written to
    ``info.text_encoder_outputs_npz`` with the embedding under a dtype-tagged
    key (``prompt_embeds_<dtype>``) and the remaining tensors under fixed keys.

    Args:
        prompt_embeds: Batched embeddings; sample ``i`` is ``prompt_embeds[i]``.
        attn_mask: Batched attention mask, stored as-is.
        t5_input_ids: Batched token ids, cast to int32.
        t5_attn_mask: Batched T5 attention mask, cast to int32.
        infos: One entry per sample. Reads ``caption_dropout_rate``,
            ``caption`` and ``text_encoder_outputs_npz``; writes
            ``text_encoder_outputs`` in memory mode.
    """
    prompt_embeds = prompt_embeds.cpu()
    attn_mask = attn_mask.cpu()
    t5_input_ids = t5_input_ids.cpu().to(torch.int32)
    t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)

    if not self.cache_to_disk:
        # numpy has no bfloat16, so Tensor.numpy() raises on it. Upcast the
        # same way the npz code path does before handing out numpy arrays.
        embeds_for_numpy = prompt_embeds.float() if prompt_embeds.dtype == torch.bfloat16 else prompt_embeds
        for i, info in enumerate(infos):
            caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
            info.text_encoder_outputs = (
                embeds_for_numpy[i].numpy(),
                attn_mask[i].numpy(),
                t5_input_ids[i].numpy(),
                t5_attn_mask[i].numpy(),
                caption_dropout_rate,
            )
        return

    # Project helpers are only needed when something is written to disk.
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
    from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

    for i, info in enumerate(infos):
        tensors = {}
        if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
            with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                for key in f.keys():
                    # The embedding is rewritten below. Dropping older
                    # variants keeps a stale "prompt_embeds_<other dtype>"
                    # from being the first prefix match when loading.
                    if key.startswith("prompt_embeds"):
                        continue
                    tensors[key] = f.get_tensor(key)

        pe = prompt_embeds[i]
        tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
        tensors["attn_mask"] = attn_mask[i]
        tensors["t5_input_ids"] = t5_input_ids[i]
        tensors["t5_attn_mask"] = t5_attn_mask[i]
        tensors["caption_dropout_rate"] = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)

        metadata = {
            "architecture": "anima",
            "caption1": info.caption,
            "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
        }
        mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
    """Return the first key that equals ``prefix`` or extends it, else ``None``.

    Keys are scanned in iteration order. A match is rejected when it would
    split a number: if ``prefix`` ends in a digit and the key continues with
    another digit, the key names a different value. This keeps
    ``latents_1x32x64`` from matching ``latents_1x32x640_float32`` while an
    open-ended prefix such as ``latents_1x`` still matches any resolution.

    Args:
        tensors_keys: Tensor names to search.
        prefix: Name, or leading part of a name, to look for.

    Returns:
        The first matching key, or ``None`` when nothing matches.
    """
    prefix_ends_in_digit = prefix[-1:].isdigit()
    cut = len(prefix)
    for key in tensors_keys:
        if not key.startswith(prefix):
            continue
        if prefix_ends_in_digit and key[cut : cut + 1].isdigit():
            # e.g. prefix "..x64" against key "..x640": a different size.
            continue
        return key
    return None
def _is_disk_cached_latents_expected_safetensors(
    self,
    latents_stride: int,
    bucket_reso: Tuple[int, int],
    st_path: str,
    flip_aug: bool,
    apply_alpha_mask: bool,
    multi_resolution: bool = False,
) -> bool:
    """Check that a safetensors latents cache holds the entries a bucket needs.

    Args:
        latents_stride: Divisor applied to the bucket size to get the latent size.
        bucket_reso: Bucket resolution as (W, H).
        st_path: Path of the safetensors cache file.
        flip_aug: Also require a ``latents_flipped_*`` entry.
        apply_alpha_mask: Also require an ``alpha_mask_*`` entry.
        multi_resolution: When true, entries must carry this bucket's
            ``1x{H}x{W}`` tag; otherwise any ``1x…`` entry is accepted.

    Returns:
        True when every required entry is present, False otherwise.

    Raises:
        Exception: Whatever opening or reading the file raises, after logging.
    """
    from library.safetensors_utils import MemoryEfficientSafeOpen

    expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # (H, W)
    reso_tag = f"1x{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "1x"

    def has_entry(keys, base: str) -> bool:
        stem = f"{base}_{reso_tag}"
        if not multi_resolution:
            # Open-ended tag ("1x"): any resolution counts.
            return any(k.startswith(stem) for k in keys)
        # Complete tag: accept the exact name (alpha_mask has no dtype
        # suffix) or "<stem>_<dtype>". A bare startswith would let 1x32x64
        # match 1x32x640.
        return any(k == stem or k.startswith(stem + "_") for k in keys)

    required = ["latents"]
    if flip_aug:
        required.append("latents_flipped")
    if apply_alpha_mask:
        required.append("alpha_mask")

    try:
        with MemoryEfficientSafeOpen(st_path) as f:
            keys = list(f.keys())
    except Exception:
        logger.error(f"Error loading file: {st_path}")
        raise

    return all(has_entry(keys, base) for base in required)
def _load_latents_from_disk_safetensors(
    self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
    """Load latents and their companion entries from a safetensors cache file.

    Args:
        latents_stride: Divisor applied to ``bucket_reso`` to get the latent
            size. If None, the first ``1x…`` entry of each kind is used.
        st_path: Path of the safetensors cache file.
        bucket_reso: Bucket resolution as (W, H).

    Returns:
        ``(latents, original_size, crop_ltrb, flipped_latents, alpha_mask)``.
        ``original_size`` and ``crop_ltrb`` fall back to zeros when absent;
        ``flipped_latents`` and ``alpha_mask`` are None when absent.

    Raises:
        ValueError: If no latents entry exists for the requested resolution.
    """
    from library.safetensors_utils import MemoryEfficientSafeOpen

    if latents_stride is None:
        reso_tag = "1x"
    else:
        latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)
        reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"

    def as_numpy(tensor):
        # numpy has no bfloat16, so Tensor.numpy() raises on it. Upcast to
        # float32, which is what the npz cache stores for bfloat16 input.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    with MemoryEfficientSafeOpen(st_path) as f:
        keys = list(f.keys())

        def locate(base: str) -> Optional[str]:
            stem = f"{base}_{reso_tag}"
            if latents_stride is None:
                return _find_tensor_by_prefix(keys, stem)
            # Complete tag: match the exact name or "<stem>_<dtype>" so that
            # 1x32x64 cannot pick up an entry saved for 1x32x640.
            if stem in keys:
                return stem
            return _find_tensor_by_prefix(keys, stem + "_")

        latents_key = locate("latents")
        if latents_key is None:
            raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
        latents = as_numpy(f.get_tensor(latents_key))

        original_size_key = locate("original_size")
        original_size = f.get_tensor(original_size_key).numpy().tolist() if original_size_key else [0, 0]

        crop_ltrb_key = locate("crop_ltrb")
        crop_ltrb = f.get_tensor(crop_ltrb_key).numpy().tolist() if crop_ltrb_key else [0, 0, 0, 0]

        flipped_key = locate("latents_flipped")
        flipped_latents = as_numpy(f.get_tensor(flipped_key)) if flipped_key else None

        alpha_mask_key = locate("alpha_mask")
        alpha_mask = as_numpy(f.get_tensor(alpha_mask_key)) if alpha_mask_key else None

    return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def _save_latents_to_disk_safetensors(
    self,
    st_path,
    latents_tensor,
    original_size,
    crop_ltrb,
    flipped_latents_tensor=None,
    alpha_mask=None,
    key_reso_suffix="",
):
    """Write latents for one resolution into a safetensors cache file.

    Entries already in ``st_path`` are kept, so one file can hold several
    resolutions. Entries for the resolution being written are replaced,
    including ones stored earlier under a different dtype suffix.

    Args:
        st_path: Path of the safetensors cache file.
        latents_tensor: Latents; the last two dims are taken as (H, W).
        original_size: Original image size, stored as int32.
        crop_ltrb: Crop box, stored as int32.
        flipped_latents_tensor: Optional flipped latents.
        alpha_mask: Optional alpha mask (tensor or array-like).
        key_reso_suffix: Accepted for signature parity with the npz writer;
            unused here because the tag is derived from the tensor shape.
    """
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen

    def without_nans(tensor):
        # Only touch tensors that contain NaN. NOTE(review): torch.nan_to_num
        # also replaces +/-inf with finite extremes by default — confirm that
        # is wanted.
        if torch.isnan(tensor).any():
            return torch.nan_to_num(tensor, nan=0.0)
        return tensor

    latents_tensor = latents_tensor.cpu()
    latents_size = latents_tensor.shape[-2:]  # H, W
    reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
    latents_tensor = without_nans(latents_tensor)

    # Build the entries for this resolution first so stale copies can be
    # filtered out of the existing file by their dtype-free stem.
    fresh: Dict[str, torch.Tensor] = {
        f"latents_{reso_tag}_{_dtype_to_str(latents_tensor.dtype)}": latents_tensor,
        f"original_size_{reso_tag}_int32": torch.tensor(original_size, dtype=torch.int32),
        f"crop_ltrb_{reso_tag}_int32": torch.tensor(crop_ltrb, dtype=torch.int32),
    }
    replaced_stems = [f"latents_{reso_tag}", f"original_size_{reso_tag}", f"crop_ltrb_{reso_tag}"]

    if flipped_latents_tensor is not None:
        flipped_latents_tensor = without_nans(flipped_latents_tensor.cpu())
        # Label the entry with its own dtype rather than the main latents'.
        fresh[f"latents_flipped_{reso_tag}_{_dtype_to_str(flipped_latents_tensor.dtype)}"] = flipped_latents_tensor
        replaced_stems.append(f"latents_flipped_{reso_tag}")

    if alpha_mask is not None:
        alpha_mask_tensor = alpha_mask.cpu() if isinstance(alpha_mask, torch.Tensor) else torch.tensor(alpha_mask)
        fresh[f"alpha_mask_{reso_tag}"] = alpha_mask_tensor
        replaced_stems.append(f"alpha_mask_{reso_tag}")

    def is_replaced(key: str) -> bool:
        return any(key == stem or key.startswith(stem + "_") for stem in replaced_stems)

    tensors: Dict[str, torch.Tensor] = {}

    # load existing file and merge (for multi-resolution)
    if os.path.exists(st_path):
        with MemoryEfficientSafeOpen(st_path) as f:
            for key in f.keys():
                if is_replaced(key):
                    # Without this, e.g. "latents_<tag>_float32" would sit
                    # next to a new "latents_<tag>_bfloat16" and a prefix
                    # lookup could return the old one.
                    continue
                tensors[key] = f.get_tensor(key)

    tensors.update(fresh)

    metadata = {
        "architecture": self._get_architecture_name(),
        "width": str(latents_size[1]),
        "height": str(latents_size[0]),
        "format_version": LATENTS_CACHE_FORMAT_VERSION,
    }

    mem_eff_save_file(tensors, st_path, metadata=metadata)
npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + if not _find_tensor_by_prefix(keys, "l_pooled"): + return False + if not _find_tensor_by_prefix(keys, "t5_out"): + return False + if not _find_tensor_by_prefix(keys, "txt_ids"): + return False + if "t5_attn_mask" not in keys: + return False + if "apply_t5_attn_mask" not in keys: + return False + apply_t5 = f.get_tensor("apply_t5_attn_mask").item() + if bool(apply_t5) != self.apply_t5_attn_mask: + return False + else: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -134,6 +156,18 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + l_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "l_pooled")).numpy() + t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy() + txt_ids = f.get_tensor(_find_tensor_by_prefix(keys, "txt_ids")).numpy() + t5_attn_mask = f.get_tensor("t5_attn_mask").numpy() + return [l_pooled, t5_out, txt_ids, t5_attn_mask] + data = np.load(npz_path) l_pooled = data["l_pooled"] t5_out = data["t5_out"] @@ -161,56 +195,100 @@ class 
FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) - if l_pooled.dtype == torch.bfloat16: - l_pooled = l_pooled.float() - if t5_out.dtype == torch.bfloat16: - t5_out = t5_out.float() - if txt_ids.dtype == torch.bfloat16: - txt_ids = txt_ids.float() + t5_attn_mask_tokens = tokens_and_masks[2] - l_pooled = l_pooled.cpu().numpy() - t5_out = t5_out.cpu().numpy() - txt_ids = txt_ids.cpu().numpy() - t5_attn_mask = tokens_and_masks[2].cpu().numpy() + if self.cache_format == "safetensors": + self._cache_batch_outputs_safetensors(l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos) + else: + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = t5_attn_mask_tokens.cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, + ) + else: + # it's fine that attn mask is not None. 
def _cache_batch_outputs_safetensors(self, l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos):
    """Cache one batch of FLUX text-encoder outputs for the safetensors cache format.

    When ``self.cache_to_disk`` is false, each sample's outputs are stored on
    ``info.text_encoder_outputs`` as numpy arrays. Otherwise each sample is
    written to ``info.text_encoder_outputs_npz`` with ``l_pooled``, ``t5_out``
    and ``txt_ids`` under dtype-tagged keys, plus ``t5_attn_mask`` and the
    ``apply_t5_attn_mask`` flag as a bool tensor.

    Args:
        l_pooled: Batched pooled output; sample ``i`` is ``l_pooled[i]``.
        t5_out: Batched T5 output.
        txt_ids: Batched text ids.
        t5_attn_mask_tokens: Batched T5 attention mask from the tokenizer.
        infos: One entry per sample. Reads ``caption`` and
            ``text_encoder_outputs_npz``; writes ``text_encoder_outputs`` in
            memory mode.
    """
    l_pooled = l_pooled.cpu()
    t5_out = t5_out.cpu()
    txt_ids = txt_ids.cpu()
    t5_attn_mask = t5_attn_mask_tokens.cpu()

    if not self.cache_to_disk:

        def numpy_ready(tensor):
            # numpy has no bfloat16; upcast as the npz code path does.
            return tensor.float() if tensor.dtype == torch.bfloat16 else tensor

        l_pooled_np = numpy_ready(l_pooled)
        t5_out_np = numpy_ready(t5_out)
        txt_ids_np = numpy_ready(txt_ids)
        for i, info in enumerate(infos):
            # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
            info.text_encoder_outputs = (
                l_pooled_np[i].numpy(),
                t5_out_np[i].numpy(),
                txt_ids_np[i].numpy(),
                t5_attn_mask[i].numpy(),
            )
        return

    # Project helpers are only needed when something is written to disk.
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
    from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

    rewritten_prefixes = ("l_pooled", "t5_out", "txt_ids")

    for i, info in enumerate(infos):
        tensors = {}
        if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
            with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                for key in f.keys():
                    # These are rewritten below. Dropping older dtype
                    # variants keeps a stale copy from being the first
                    # prefix match when loading.
                    if key.startswith(rewritten_prefixes):
                        continue
                    tensors[key] = f.get_tensor(key)

        lp = l_pooled[i]
        to = t5_out[i]
        ti = txt_ids[i]
        tensors[f"l_pooled_{_dtype_to_str(lp.dtype)}"] = lp
        tensors[f"t5_out_{_dtype_to_str(to.dtype)}"] = to
        tensors[f"txt_ids_{_dtype_to_str(ti.dtype)}"] = ti
        tensors["t5_attn_mask"] = t5_attn_mask[i]
        tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)

        metadata = {
            "architecture": "flux",
            "caption1": info.caption,
            "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
        }
        mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False, ) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) def get_outputs_npz_path(self, image_abs_path: str) -> str: + suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX return ( - os.path.splitext(image_abs_path)[0] - + HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + os.path.splitext(image_abs_path)[0] + suffix ) def is_disk_cached_outputs_expected(self, npz_path: str): @@ -102,17 +103,34 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr return True try: - npz = np.load(npz_path) - if "vlm_embed" not in npz: - return False - if "vlm_mask" not in npz: - return False - if "byt5_embed" not in npz: - return False - if "byt5_mask" not in npz: - return False - if "ocr_mask" not in npz: - return False + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + if not _find_tensor_by_prefix(keys, "vlm_embed"): + return False + if "vlm_mask" not in keys: + return False + if not _find_tensor_by_prefix(keys, "byt5_embed"): + return False + if "byt5_mask" not in keys: + return False + if "ocr_mask" not in keys: + return False + else: + npz = np.load(npz_path) + if "vlm_embed" not in npz: + return False + if "vlm_mask" not in npz: + return False + if "byt5_embed" not in npz: + return False + if "byt5_mask" not in npz: + return False + if "ocr_mask" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -120,6 +138,19 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr return True def load_outputs_npz(self, npz_path: str) -> 
List[np.ndarray]: + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + vlm_embed = f.get_tensor(_find_tensor_by_prefix(keys, "vlm_embed")).numpy() + vlm_mask = f.get_tensor("vlm_mask").numpy() + byt5_embed = f.get_tensor(_find_tensor_by_prefix(keys, "byt5_embed")).numpy() + byt5_mask = f.get_tensor("byt5_mask").numpy() + ocr_mask = f.get_tensor("ocr_mask").numpy() + return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask] + data = np.load(npz_path) vln_embed = data["vlm_embed"] vlm_mask = data["vlm_mask"] @@ -140,54 +171,102 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr tokenize_strategy, models, tokens_and_masks ) - if vlm_embed.dtype == torch.bfloat16: - vlm_embed = vlm_embed.float() - if byt5_embed.dtype == torch.bfloat16: - byt5_embed = byt5_embed.float() + if self.cache_format == "safetensors": + self._cache_batch_outputs_safetensors(vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos) + else: + if vlm_embed.dtype == torch.bfloat16: + vlm_embed = vlm_embed.float() + if byt5_embed.dtype == torch.bfloat16: + byt5_embed = byt5_embed.float() - vlm_embed = vlm_embed.cpu().numpy() - vlm_mask = vlm_mask.cpu().numpy() - byt5_embed = byt5_embed.cpu().numpy() - byt5_mask = byt5_mask.cpu().numpy() - ocr_mask = ocr_mask.cpu().numpy() + vlm_embed = vlm_embed.cpu().numpy() + vlm_mask = vlm_mask.cpu().numpy() + byt5_embed = byt5_embed.cpu().numpy() + byt5_mask = byt5_mask.cpu().numpy() + ocr_mask = ocr_mask.cpu().numpy() + + for i, info in enumerate(infos): + vlm_embed_i = vlm_embed[i] + vlm_mask_i = vlm_mask[i] + byt5_embed_i = byt5_embed[i] + byt5_mask_i = byt5_mask[i] + ocr_mask_i = ocr_mask[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + vlm_embed=vlm_embed_i, + vlm_mask=vlm_mask_i, + byt5_embed=byt5_embed_i, + 
def _cache_batch_outputs_safetensors(self, vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos):
    """Cache one batch of HunyuanImage text-encoder outputs for the safetensors cache format.

    When ``self.cache_to_disk`` is false, each sample's outputs are stored on
    ``info.text_encoder_outputs`` as numpy arrays. Otherwise each sample is
    written to ``info.text_encoder_outputs_npz`` with the two embeddings under
    dtype-tagged keys (``vlm_embed_<dtype>``, ``byt5_embed_<dtype>``) and the
    three masks under fixed keys.

    Args:
        vlm_embed: Batched VLM embeddings; sample ``i`` is ``vlm_embed[i]``.
        vlm_mask: Batched VLM mask.
        byt5_embed: Batched ByT5 embeddings.
        byt5_mask: Batched ByT5 mask.
        ocr_mask: Batched OCR mask.
        infos: One entry per sample. Reads ``caption`` and
            ``text_encoder_outputs_npz``; writes ``text_encoder_outputs`` in
            memory mode.
    """
    vlm_embed = vlm_embed.cpu()
    vlm_mask = vlm_mask.cpu()
    byt5_embed = byt5_embed.cpu()
    byt5_mask = byt5_mask.cpu()
    ocr_mask = ocr_mask.cpu()

    if not self.cache_to_disk:
        # numpy has no bfloat16, so Tensor.numpy() raises on it. Upcast the
        # embeddings the same way the npz code path does.
        vlm_embed_np = vlm_embed.float() if vlm_embed.dtype == torch.bfloat16 else vlm_embed
        byt5_embed_np = byt5_embed.float() if byt5_embed.dtype == torch.bfloat16 else byt5_embed
        for i, info in enumerate(infos):
            info.text_encoder_outputs = (
                vlm_embed_np[i].numpy(),
                vlm_mask[i].numpy(),
                byt5_embed_np[i].numpy(),
                byt5_mask[i].numpy(),
                ocr_mask[i].numpy(),
            )
        return

    # Project helpers are only needed when something is written to disk.
    from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
    from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

    rewritten_prefixes = ("vlm_embed", "byt5_embed")

    for i, info in enumerate(infos):
        tensors = {}
        if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
            with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                for key in f.keys():
                    # The embeddings are rewritten below. Dropping older
                    # dtype variants keeps a stale copy from being the
                    # first prefix match when loading.
                    if key.startswith(rewritten_prefixes):
                        continue
                    tensors[key] = f.get_tensor(key)

        ve = vlm_embed[i]
        be = byt5_embed[i]
        tensors[f"vlm_embed_{_dtype_to_str(ve.dtype)}"] = ve
        tensors["vlm_mask"] = vlm_mask[i]
        tensors[f"byt5_embed_{_dtype_to_str(be.dtype)}"] = be
        tensors["byt5_mask"] = byt5_mask[i]
        tensors["ocr_mask"] = ocr_mask[i]

        metadata = {
            "architecture": "hunyuan_image",
            "caption1": info.caption,
            "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
        }
        mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
= "_hi.npz" + HUNYUAN_IMAGE_LATENTS_ST_SUFFIX = "_hi.safetensors" def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) @property def cache_suffix(self) -> str: - return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX + return self.HUNYUAN_IMAGE_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX + + self.cache_suffix ) + def _get_architecture_name(self) -> str: + return "hunyuan_image" + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 964d9f7a..c2638fee 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -146,6 +146,7 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz" + LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors" def __init__( self, @@ -162,19 +163,10 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) ) def get_outputs_npz_path(self, image_abs_path: str) -> str: - return ( - os.path.splitext(image_abs_path)[0] - + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - ) + suffix = self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + 
return os.path.splitext(image_abs_path)[0] + suffix def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: - """ - Args: - npz_path (str): Path to the npz file. - - Returns: - bool: True if the npz file is expected to be cached. - """ if not self.cache_to_disk: return False if not os.path.exists(npz_path): @@ -183,13 +175,26 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) return True try: - npz = np.load(npz_path) - if "hidden_state" not in npz: - return False - if "attention_mask" not in npz: - return False - if "input_ids" not in npz: - return False + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + if not _find_tensor_by_prefix(keys, "hidden_state"): + return False + if "attention_mask" not in keys: + return False + if "input_ids" not in keys: + return False + else: + npz = np.load(npz_path) + if "hidden_state" not in npz: + return False + if "attention_mask" not in npz: + return False + if "input_ids" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -198,11 +203,22 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: """ - Load outputs from a npz file + Load outputs from a npz/safetensors file Returns: List[np.ndarray]: hidden_state, input_ids, attention_mask """ + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + hidden_state = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")).numpy() + attention_mask = f.get_tensor("attention_mask").numpy() + input_ids = f.get_tensor("input_ids").numpy() + return 
[hidden_state, input_ids, attention_mask] + data = np.load(npz_path) hidden_state = data["hidden_state"] attention_mask = data["attention_mask"] @@ -217,16 +233,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) text_encoding_strategy: TextEncodingStrategy, batch: List[train_util.ImageInfo], ) -> None: - """ - Args: - tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy - models (List[Any]): Text encoders - text_encoding_strategy (LuminaTextEncodingStrategy): - infos (List): List of ImageInfo - - Returns: - None - """ assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) @@ -252,37 +258,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) ) ) - if hidden_state.dtype != torch.float32: - hidden_state = hidden_state.float() + if self.cache_format == "safetensors": + self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch) + else: + if hidden_state.dtype != torch.float32: + hidden_state = hidden_state.float() - hidden_state = hidden_state.cpu().numpy() - attention_mask = attention_masks.cpu().numpy() # (B, S) - input_ids = input_ids.cpu().numpy() # (B, S) + hidden_state = hidden_state.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() + input_ids_np = input_ids.cpu().numpy() + for i, info in enumerate(batch): + hidden_state_i = hidden_state[i] + attention_mask_i = attention_mask[i] + input_ids_i = input_ids_np[i] + + if self.cache_to_disk: + assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}" + np.savez( + info.text_encoder_outputs_npz, + hidden_state=hidden_state_i, + attention_mask=attention_mask_i, + input_ids=input_ids_i, + ) + else: + info.text_encoder_outputs = [ + hidden_state_i, + input_ids_i, + attention_mask_i, + ] + + def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, 
attention_masks, batch): + from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen + from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION + + hidden_state = hidden_state.cpu() + input_ids = input_ids.cpu() + attention_mask = attention_masks.cpu() for i, info in enumerate(batch): - hidden_state_i = hidden_state[i] - attention_mask_i = attention_mask[i] - input_ids_i = input_ids[i] - if self.cache_to_disk: assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}" - np.savez( - info.text_encoder_outputs_npz, - hidden_state=hidden_state_i, - attention_mask=attention_mask_i, - input_ids=input_ids_i, - ) + tensors = {} + if self.is_partial and os.path.exists(info.text_encoder_outputs_npz): + with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + + hs = hidden_state[i] + tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs + tensors["attention_mask"] = attention_mask[i] + tensors["input_ids"] = input_ids[i] + + metadata = { + "architecture": "lumina", + "caption1": info.caption, + "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION, + } + mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata) else: info.text_encoder_outputs = [ - hidden_state_i, - input_ids_i, - attention_mask_i, + hidden_state[i].numpy(), + input_ids[i].numpy(), + attention_mask[i].numpy(), ] class LuminaLatentsCachingStrategy(LatentsCachingStrategy): LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz" + LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors" def __init__( self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool @@ -291,7 +335,7 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): @property def cache_suffix(self) -> str: - return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" 
else self.LUMINA_LATENTS_NPZ_SUFFIX def get_latents_npz_path( self, absolute_path: str, image_size: Tuple[int, int] @@ -299,9 +343,12 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): return ( os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + + self.cache_suffix ) + def _get_architecture_name(self) -> str: + return "lumina" + def is_disk_cached_latents_expected( self, bucket_reso: Tuple[int, int], diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 4521ae8d..9f68a8d6 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -138,24 +138,32 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" SD_LATENTS_NPZ_SUFFIX = "_sd.npz" SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + SD_LATENTS_ST_SUFFIX = "_sd.safetensors" + SDXL_LATENTS_ST_SUFFIX = "_sdxl.safetensors" - def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + def __init__( + self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool + ) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) self.sd = sd - self.suffix = ( - SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX - ) @property def cache_suffix(self) -> str: - return self.suffix + if self.cache_format == "safetensors": + return self.SD_LATENTS_ST_SUFFIX if self.sd else self.SDXL_LATENTS_ST_SUFFIX + else: + return self.SD_LATENTS_NPZ_SUFFIX if self.sd else self.SDXL_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - # support old .npz - old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX - if os.path.exists(old_npz_file): - return old_npz_file - return os.path.splitext(absolute_path)[0] + 
f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + if self.cache_format != "safetensors": + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix + + def _get_architecture_name(self) -> str: + return "sd" if self.sd else "sdxl" def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 1d55fe21..934da2fd 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -255,6 +255,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_sd3_te.safetensors" def __init__( self, @@ -270,7 +271,8 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): self.apply_t5_attn_mask = apply_t5_attn_mask def get_outputs_npz_path(self, image_abs_path: str) -> str: - return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + suffix = self.SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + return os.path.splitext(image_abs_path)[0] + suffix def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: @@ -281,27 +283,54 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True try: - npz = np.load(npz_path) - if "lg_out" not in npz: - return False - if "lg_pooled" not in npz: - return False - if 
"clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used - return False - if "apply_lg_attn_mask" not in npz: - return False - if "t5_out" not in npz: - return False - if "t5_attn_mask" not in npz: - return False - npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] - if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: - return False - if "apply_t5_attn_mask" not in npz: - return False - npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] - if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: - return False + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + if not _find_tensor_by_prefix(keys, "lg_out"): + return False + if not _find_tensor_by_prefix(keys, "lg_pooled"): + return False + if "clip_l_attn_mask" not in keys or "clip_g_attn_mask" not in keys: + return False + if not _find_tensor_by_prefix(keys, "t5_out"): + return False + if "t5_attn_mask" not in keys: + return False + if "apply_lg_attn_mask" not in keys: + return False + apply_lg = f.get_tensor("apply_lg_attn_mask").item() + if bool(apply_lg) != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in keys: + return False + apply_t5 = f.get_tensor("apply_t5_attn_mask").item() + if bool(apply_t5) != self.apply_t5_attn_mask: + return False + else: + npz = np.load(npz_path) + if "lg_out" not in npz: + return False + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: + return False + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask 
= npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -309,6 +338,20 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + lg_out = f.get_tensor(_find_tensor_by_prefix(keys, "lg_out")).numpy() + lg_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "lg_pooled")).numpy() + t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy() + l_attn_mask = f.get_tensor("clip_l_attn_mask").numpy() + g_attn_mask = f.get_tensor("clip_g_attn_mask").numpy() + t5_attn_mask = f.get_tensor("t5_attn_mask").numpy() + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] @@ -339,65 +382,127 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): enable_dropout=False, ) - if lg_out.dtype == torch.bfloat16: - lg_out = lg_out.float() - if lg_pooled.dtype == torch.bfloat16: - lg_pooled = lg_pooled.float() - if t5_out.dtype == torch.bfloat16: - t5_out = t5_out.float() + l_attn_mask_tokens = tokens_and_masks[3] + g_attn_mask_tokens = tokens_and_masks[4] + t5_attn_mask_tokens = tokens_and_masks[5] - lg_out = lg_out.cpu().numpy() - lg_pooled = lg_pooled.cpu().numpy() - t5_out = t5_out.cpu().numpy() + if self.cache_format == "safetensors": + self._cache_batch_outputs_safetensors( + lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos + ) + else: + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = 
lg_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() - l_attn_mask = tokens_and_masks[3].cpu().numpy() - g_attn_mask = tokens_and_masks[4].cpu().numpy() - t5_attn_mask = tokens_and_masks[5].cpu().numpy() + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = l_attn_mask_tokens.cpu().numpy() + g_attn_mask = g_attn_mask_tokens.cpu().numpy() + t5_attn_mask = t5_attn_mask_tokens.cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] + lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + apply_lg_attn_mask=apply_lg_attn_mask, + apply_t5_attn_mask=apply_t5_attn_mask, + ) + else: + # it's fine that attn mask is not None. 
it's overwritten before calling the model if necessary + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) + + def _cache_batch_outputs_safetensors( + self, lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos + ): + from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen + from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION + + lg_out = lg_out.cpu() + t5_out = t5_out.cpu() + lg_pooled = lg_pooled.cpu() + l_attn_mask = l_attn_mask_tokens.cpu() + g_attn_mask = g_attn_mask_tokens.cpu() + t5_attn_mask = t5_attn_mask_tokens.cpu() for i, info in enumerate(infos): - lg_out_i = lg_out[i] - t5_out_i = t5_out[i] - lg_pooled_i = lg_pooled[i] - l_attn_mask_i = l_attn_mask[i] - g_attn_mask_i = g_attn_mask[i] - t5_attn_mask_i = t5_attn_mask[i] - apply_lg_attn_mask = self.apply_lg_attn_mask - apply_t5_attn_mask = self.apply_t5_attn_mask - if self.cache_to_disk: - np.savez( - info.text_encoder_outputs_npz, - lg_out=lg_out_i, - lg_pooled=lg_pooled_i, - t5_out=t5_out_i, - clip_l_attn_mask=l_attn_mask_i, - clip_g_attn_mask=g_attn_mask_i, - t5_attn_mask=t5_attn_mask_i, - apply_lg_attn_mask=apply_lg_attn_mask, - apply_t5_attn_mask=apply_t5_attn_mask, - ) + tensors = {} + if self.is_partial and os.path.exists(info.text_encoder_outputs_npz): + with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] + lg_pooled_i = lg_pooled[i] + tensors[f"lg_out_{_dtype_to_str(lg_out_i.dtype)}"] = lg_out_i + tensors[f"t5_out_{_dtype_to_str(t5_out_i.dtype)}"] = t5_out_i + tensors[f"lg_pooled_{_dtype_to_str(lg_pooled_i.dtype)}"] = lg_pooled_i + tensors["clip_l_attn_mask"] = l_attn_mask[i] + tensors["clip_g_attn_mask"] = g_attn_mask[i] + tensors["t5_attn_mask"] = t5_attn_mask[i] + tensors["apply_lg_attn_mask"] = 
torch.tensor(self.apply_lg_attn_mask, dtype=torch.bool) + tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool) + + metadata = { + "architecture": "sd3", + "caption1": info.caption, + "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION, + } + mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata) else: - # it's fine that attn mask is not None. it's overwritten before calling the model if necessary - info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) + info.text_encoder_outputs = ( + lg_out[i].numpy(), + t5_out[i].numpy(), + lg_pooled[i].numpy(), + l_attn_mask[i].numpy(), + g_attn_mask[i].numpy(), + t5_attn_mask[i].numpy(), + ) class Sd3LatentsCachingStrategy(LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + SD3_LATENTS_ST_SUFFIX = "_sd3.safetensors" def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) @property def cache_suffix(self) -> str: - return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + return self.SD3_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + + self.cache_suffix ) + def _get_architecture_name(self) -> str: + return "sd3" + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 6b3e2afa..176d7667 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -221,6 
+221,7 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy): class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_te_outputs.safetensors" def __init__( self, @@ -233,7 +234,8 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) def get_outputs_npz_path(self, image_abs_path: str) -> str: - return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + suffix = self.SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + return os.path.splitext(image_abs_path)[0] + suffix def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: @@ -244,9 +246,22 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True try: - npz = np.load(npz_path) - if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: - return False + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + if not _find_tensor_by_prefix(keys, "hidden_state1"): + return False + if not _find_tensor_by_prefix(keys, "hidden_state2"): + return False + if not _find_tensor_by_prefix(keys, "pool2"): + return False + else: + npz = np.load(npz_path) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -254,6 +269,17 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return True def load_outputs_npz(self, npz_path: str) -> 
List[np.ndarray]: + if npz_path.endswith(".safetensors"): + from library.safetensors_utils import MemoryEfficientSafeOpen + from library.strategy_base import _find_tensor_by_prefix + + with MemoryEfficientSafeOpen(npz_path) as f: + keys = f.keys() + hidden_state1 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state1")).numpy() + hidden_state2 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state2")).numpy() + pool2 = f.get_tensor(_find_tensor_by_prefix(keys, "pool2")).numpy() + return [hidden_state1, hidden_state2, pool2] + data = np.load(npz_path) hidden_state1 = data["hidden_state1"] hidden_state2 = data["hidden_state2"] @@ -279,28 +305,68 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokenize_strategy, models, [tokens1, tokens2] ) - if hidden_state1.dtype == torch.bfloat16: - hidden_state1 = hidden_state1.float() - if hidden_state2.dtype == torch.bfloat16: - hidden_state2 = hidden_state2.float() - if pool2.dtype == torch.bfloat16: - pool2 = pool2.float() + if self.cache_format == "safetensors": + self._cache_batch_outputs_safetensors(hidden_state1, hidden_state2, pool2, infos) + else: + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() - hidden_state1 = hidden_state1.cpu().numpy() - hidden_state2 = hidden_state2.cpu().numpy() - pool2 = pool2.cpu().numpy() + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] + + 
def _cache_batch_outputs_safetensors(self, hidden_state1, hidden_state2, pool2, infos): + from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen + from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION + + hidden_state1 = hidden_state1.cpu() + hidden_state2 = hidden_state2.cpu() + pool2 = pool2.cpu() for i, info in enumerate(infos): - hidden_state1_i = hidden_state1[i] - hidden_state2_i = hidden_state2[i] - pool2_i = pool2[i] - if self.cache_to_disk: - np.savez( - info.text_encoder_outputs_npz, - hidden_state1=hidden_state1_i, - hidden_state2=hidden_state2_i, - pool2=pool2_i, - ) + tensors = {} + # merge existing file if partial + if self.is_partial and os.path.exists(info.text_encoder_outputs_npz): + with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + + hs1 = hidden_state1[i] + hs2 = hidden_state2[i] + p2 = pool2[i] + tensors[f"hidden_state1_{_dtype_to_str(hs1.dtype)}"] = hs1 + tensors[f"hidden_state2_{_dtype_to_str(hs2.dtype)}"] = hs2 + tensors[f"pool2_{_dtype_to_str(p2.dtype)}"] = p2 + + metadata = { + "architecture": "sdxl", + "caption1": info.caption, + "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION, + } + mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata) else: - info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] + info.text_encoder_outputs = [ + hidden_state1[i].numpy(), + hidden_state2[i].numpy(), + pool2[i].numpy(), + ] diff --git a/library/train_util.py b/library/train_util.py index b65f06b9..0211f892 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset): return all( [ not ( - subset.caption_dropout_rate > 0 and not cache_supports_dropout + subset.caption_dropout_rate > 0 + and not cache_supports_dropout or subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0 @@ 
-4471,7 +4472,10 @@ def verify_training_args(args: argparse.Namespace): Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ + from library.strategy_base import set_cache_format + enable_high_vram(args) + set_cache_format(args.cache_format) if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") @@ -4637,6 +4641,14 @@ def add_dataset_arguments( help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", ) + parser.add_argument( + "--cache_format", + type=str, + default="npz", + choices=["npz", "safetensors"], + help="format for latent and text encoder output caches (default: npz). safetensors saves in native dtype (e.g. bf16) for smaller files and faster I/O" + " / latentおよびtext encoder出力キャッシュの保存形式(デフォルト: npz)。safetensorsはネイティブdtype(例: bf16)で保存し、ファイルサイズ削減と高速化が可能", + ) parser.add_argument( "--enable_bucket", action="store_true", diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 5baddb5b..21e79889 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -69,6 +69,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + strategy_base.set_cache_format(args.cache_format) + if is_sd or is_sdxl: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 8e604292..ba4db1b4 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -156,6 +156,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: text_encoder.eval() # build text 
encoder outputs caching strategy + strategy_base.set_cache_format(args.cache_format) + if is_sdxl: text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions