Compare commits

..

9 Commits

Author SHA1 Message Date
Kohya S.
308a0cc9fc Merge pull request #2312 from kohya-ss/dev
Merge dev to main
2026-04-07 08:53:13 +09:00
Kohya S
7e60e163c1 Merge branch 'main' into dev 2026-04-07 08:49:58 +09:00
Kohya S.
a8f5c222e0 Merge pull request #2311 from kohya-ss/doc-update-readme-for-next-release
README: Add planned changes for next release (intel GPU compatibility)
2026-04-07 08:47:37 +09:00
Kohya S
1d588d6cb6 README: Add planned changes for next release and improve Intel GPU compatibility 2026-04-07 08:44:31 +09:00
Kohya S.
a7d35701a0 Merge pull request #2307 from WhitePr/dev
update ipex
2026-04-07 08:41:41 +09:00
WhitePr
8da05a10dc Update IPEX libs 2026-04-04 05:37:18 +09:00
WhitePr
197b129284 Modifying the method for get the Torch version 2026-04-04 04:44:53 +09:00
Kohya S.
51435f1718 Merge pull request #2303 from kohya-ss/sd3
fix: improve numerical stability by conditionally using float32 in Anima with fp16 training
2026-04-02 12:40:48 +09:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
15 changed files with 301 additions and 942 deletions

View File

@@ -50,6 +50,12 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,12 @@ If you find this project helpful, please consider supporting its development via
### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -1,6 +1,7 @@
import os
import sys
import torch
from packaging import version
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
@@ -8,7 +9,7 @@ except Exception:
has_ipex = False
from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3])
torch_version = version.parse(torch.__version__)
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5:
if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7:
if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# C
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12

View File

@@ -155,7 +155,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
"""
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors"
def __init__(
self,
@@ -167,8 +166,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
if not self.cache_to_disk:
@@ -179,34 +177,17 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "prompt_embeds"):
return False
if "attn_mask" not in keys:
return False
if "t5_input_ids" not in keys:
return False
if "t5_attn_mask" not in keys:
return False
if "caption_dropout_rate" not in keys:
return False
else:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -214,19 +195,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy()
attn_mask = f.get_tensor("attn_mask").numpy()
t5_input_ids = f.get_tensor("t5_input_ids").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy()
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
data = np.load(npz_path)
prompt_embeds = data["prompt_embeds"]
attn_mask = data["attn_mask"]
@@ -251,75 +219,32 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, tokens_and_masks
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos)
else:
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
prompt_embeds = prompt_embeds.cpu()
attn_mask = attn_mask.cpu()
t5_input_ids = t5_input_ids.cpu().to(torch.int32)
t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
pe = prompt_embeds[i]
tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
tensors["attn_mask"] = attn_mask[i]
tensors["t5_input_ids"] = t5_input_ids[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["caption_dropout_rate"] = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
metadata = {
"architecture": "anima",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
info.text_encoder_outputs = (
prompt_embeds[i].numpy(),
attn_mask[i].numpy(),
t5_input_ids[i].numpy(),
t5_attn_mask[i].numpy(),
caption_dropout_rate,
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -330,20 +255,16 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
"""
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
ANIMA_LATENTS_ST_SUFFIX = "_anima.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.ANIMA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_LATENTS_NPZ_SUFFIX
return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "anima"
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -2,7 +2,7 @@
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@@ -19,48 +19,6 @@ import logging
logger = logging.getLogger(__name__)
LATENTS_CACHE_FORMAT_VERSION = "1.0.1"
TE_OUTPUTS_CACHE_FORMAT_VERSION = "1.0.1"
# global cache format setting: "npz" or "safetensors"
_cache_format: str = "npz"
def set_cache_format(cache_format: str) -> None:
global _cache_format
_cache_format = cache_format
def get_cache_format() -> str:
return _cache_format
_TORCH_DTYPE_TO_STR = {
torch.float64: "float64",
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
torch.int16: "int16",
torch.int8: "int8",
torch.uint8: "uint8",
torch.bool: "bool",
}
_FLOAT_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16}
def _dtype_to_str(dtype: torch.dtype) -> str:
return _TORCH_DTYPE_TO_STR.get(dtype, str(dtype).replace("torch.", ""))
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
"""Find a tensor key that starts with the given prefix. Returns the first match or None."""
for key in tensors_keys:
if key.startswith(prefix) or key == prefix:
return key
return None
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@@ -404,10 +362,6 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self):
return self._is_weighted
@property
def cache_format(self) -> str:
return get_cache_format()
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
@@ -453,10 +407,6 @@ class LatentsCachingStrategy:
def batch_size(self):
return self._batch_size
@property
def cache_format(self) -> str:
return get_cache_format()
@property
def cache_suffix(self):
raise NotImplementedError
@@ -489,7 +439,7 @@ class LatentsCachingStrategy:
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz/safetensors file
npz_path: path to the npz file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
@@ -504,11 +454,6 @@ class LatentsCachingStrategy:
if self.skip_disk_cache_validity_check:
return True
if npz_path.endswith(".safetensors"):
return self._is_disk_cached_latents_expected_safetensors(
latents_stride, bucket_reso, npz_path, flip_aug, apply_alpha_mask, multi_resolution
)
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
@@ -531,40 +476,6 @@ class LatentsCachingStrategy:
return True
def _is_disk_cached_latents_expected_safetensors(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
st_path: str,
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
) -> bool:
from library.safetensors_utils import MemoryEfficientSafeOpen
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # (H, W)
reso_tag = f"1x{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "1x"
try:
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_prefix = f"latents_{reso_tag}"
if not any(k.startswith(latents_prefix) for k in keys):
return False
if flip_aug:
flipped_prefix = f"latents_flipped_{reso_tag}"
if not any(k.startswith(flipped_prefix) for k in keys):
return False
if apply_alpha_mask:
mask_prefix = f"alpha_mask_{reso_tag}"
if not any(k.startswith(mask_prefix) for k in keys):
return False
except Exception as e:
logger.error(f"Error loading file: {st_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
@@ -660,7 +571,7 @@ class LatentsCachingStrategy:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz/safetensors file.
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
@@ -672,9 +583,6 @@ class LatentsCachingStrategy:
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if npz_path.endswith(".safetensors"):
return self._load_latents_from_disk_safetensors(latents_stride, npz_path, bucket_reso)
if latents_stride is None:
key_reso_suffix = ""
else:
@@ -701,39 +609,6 @@ class LatentsCachingStrategy:
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def _load_latents_from_disk_safetensors(
self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
from library.safetensors_utils import MemoryEfficientSafeOpen
if latents_stride is None:
reso_tag = "1x"
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_key = _find_tensor_by_prefix(keys, f"latents_{reso_tag}")
if latents_key is None:
raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
latents = f.get_tensor(latents_key).numpy()
original_size_key = _find_tensor_by_prefix(keys, f"original_size_{reso_tag}")
original_size = f.get_tensor(original_size_key).numpy().tolist() if original_size_key else [0, 0]
crop_ltrb_key = _find_tensor_by_prefix(keys, f"crop_ltrb_{reso_tag}")
crop_ltrb = f.get_tensor(crop_ltrb_key).numpy().tolist() if crop_ltrb_key else [0, 0, 0, 0]
flipped_key = _find_tensor_by_prefix(keys, f"latents_flipped_{reso_tag}")
flipped_latents = f.get_tensor(flipped_key).numpy() if flipped_key else None
alpha_mask_key = _find_tensor_by_prefix(keys, f"alpha_mask_{reso_tag}")
alpha_mask = f.get_tensor(alpha_mask_key).numpy() if alpha_mask_key else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
@@ -746,23 +621,17 @@ class LatentsCachingStrategy:
):
"""
Args:
npz_path (str): Path to the npz/safetensors file.
npz_path (str): Path to the npz file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix (e.g. "_32x64" for multi-resolution npz)
key_reso_suffix (str): Key resolution suffix
Returns:
None
"""
if npz_path.endswith(".safetensors"):
self._save_latents_to_disk_safetensors(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor, alpha_mask, key_reso_suffix
)
return
kwargs = {}
if os.path.exists(npz_path):
@@ -771,7 +640,7 @@ class LatentsCachingStrategy:
for key in npz.files:
kwargs[key] = npz[key]
# float() is needed because npz doesn't support bfloat16
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16.
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
@@ -780,59 +649,3 @@ class LatentsCachingStrategy:
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)
def _save_latents_to_disk_safetensors(
self,
st_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
latents_tensor = latents_tensor.cpu()
latents_size = latents_tensor.shape[-2:] # H, W
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
dtype_str = _dtype_to_str(latents_tensor.dtype)
# NaN check and zero replacement
if torch.isnan(latents_tensor).any():
latents_tensor = torch.nan_to_num(latents_tensor, nan=0.0)
tensors: Dict[str, torch.Tensor] = {}
# load existing file and merge (for multi-resolution)
if os.path.exists(st_path):
with MemoryEfficientSafeOpen(st_path) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
tensors[f"latents_{reso_tag}_{dtype_str}"] = latents_tensor
tensors[f"original_size_{reso_tag}_int32"] = torch.tensor(original_size, dtype=torch.int32)
tensors[f"crop_ltrb_{reso_tag}_int32"] = torch.tensor(crop_ltrb, dtype=torch.int32)
if flipped_latents_tensor is not None:
flipped_latents_tensor = flipped_latents_tensor.cpu()
if torch.isnan(flipped_latents_tensor).any():
flipped_latents_tensor = torch.nan_to_num(flipped_latents_tensor, nan=0.0)
tensors[f"latents_flipped_{reso_tag}_{dtype_str}"] = flipped_latents_tensor
if alpha_mask is not None:
alpha_mask_tensor = alpha_mask.cpu() if isinstance(alpha_mask, torch.Tensor) else torch.tensor(alpha_mask)
tensors[f"alpha_mask_{reso_tag}"] = alpha_mask_tensor
metadata = {
"architecture": self._get_architecture_name(),
"width": str(latents_size[1]),
"height": str(latents_size[0]),
"format_version": LATENTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, st_path, metadata=metadata)
def _get_architecture_name(self) -> str:
"""Override in subclasses to return the architecture name for safetensors metadata."""
return "unknown"

View File

@@ -87,7 +87,6 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_flux_te.safetensors"
def __init__(
self,
@@ -103,8 +102,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -115,40 +113,20 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "l_pooled"):
return False
if not _find_tensor_by_prefix(keys, "t5_out"):
return False
if not _find_tensor_by_prefix(keys, "txt_ids"):
return False
if "t5_attn_mask" not in keys:
return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -156,18 +134,6 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
l_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "l_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
txt_ids = f.get_tensor(_find_tensor_by_prefix(keys, "txt_ids")).numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
@@ -195,100 +161,56 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
t5_attn_mask_tokens = tokens_and_masks[2]
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos)
else:
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(self, l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
lp = l_pooled[i]
to = t5_out[i]
ti = txt_ids[i]
tensors[f"l_pooled_{_dtype_to_str(lp.dtype)}"] = lp
tensors[f"t5_out_{_dtype_to_str(to.dtype)}"] = to
tensors[f"txt_ids_{_dtype_to_str(ti.dtype)}"] = ti
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "flux",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
info.text_encoder_outputs = (l_pooled[i].numpy(), t5_out[i].numpy(), txt_ids[i].numpy(), t5_attn_mask[i].numpy())
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
FLUX_LATENTS_ST_SUFFIX = "_flux.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.FLUX_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_LATENTS_NPZ_SUFFIX
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.cache_suffix
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
)
def _get_architecture_name(self) -> str:
return "flux"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -81,17 +81,16 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_hi_te.safetensors"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False,
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return (
os.path.splitext(image_abs_path)[0] + suffix
os.path.splitext(image_abs_path)[0]
+ HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str):
@@ -103,34 +102,17 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "vlm_embed"):
return False
if "vlm_mask" not in keys:
return False
if not _find_tensor_by_prefix(keys, "byt5_embed"):
return False
if "byt5_mask" not in keys:
return False
if "ocr_mask" not in keys:
return False
else:
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
if "vlm_mask" not in npz:
return False
if "byt5_embed" not in npz:
return False
if "byt5_mask" not in npz:
return False
if "ocr_mask" not in npz:
return False
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
if "vlm_mask" not in npz:
return False
if "byt5_embed" not in npz:
return False
if "byt5_mask" not in npz:
return False
if "ocr_mask" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -138,19 +120,6 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
vlm_embed = f.get_tensor(_find_tensor_by_prefix(keys, "vlm_embed")).numpy()
vlm_mask = f.get_tensor("vlm_mask").numpy()
byt5_embed = f.get_tensor(_find_tensor_by_prefix(keys, "byt5_embed")).numpy()
byt5_mask = f.get_tensor("byt5_mask").numpy()
ocr_mask = f.get_tensor("ocr_mask").numpy()
return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
data = np.load(npz_path)
vln_embed = data["vlm_embed"]
vlm_mask = data["vlm_mask"]
@@ -171,102 +140,54 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
tokenize_strategy, models, tokens_and_masks
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos)
else:
if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
byt5_embed = byt5_embed.float()
if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
byt5_embed = byt5_embed.float()
vlm_embed = vlm_embed.cpu().numpy()
vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = ocr_mask.cpu().numpy()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
vlm_embed=vlm_embed_i,
vlm_mask=vlm_mask_i,
byt5_embed=byt5_embed_i,
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i,
)
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
def _cache_batch_outputs_safetensors(self, vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
vlm_embed = vlm_embed.cpu()
vlm_mask = vlm_mask.cpu()
byt5_embed = byt5_embed.cpu()
byt5_mask = byt5_mask.cpu()
ocr_mask = ocr_mask.cpu()
vlm_embed = vlm_embed.cpu().numpy()
vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = ocr_mask.cpu().numpy()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
vlm_mask_i = vlm_mask[i]
byt5_embed_i = byt5_embed[i]
byt5_mask_i = byt5_mask[i]
ocr_mask_i = ocr_mask[i]
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
ve = vlm_embed[i]
be = byt5_embed[i]
tensors[f"vlm_embed_{_dtype_to_str(ve.dtype)}"] = ve
tensors["vlm_mask"] = vlm_mask[i]
tensors[f"byt5_embed_{_dtype_to_str(be.dtype)}"] = be
tensors["byt5_mask"] = byt5_mask[i]
tensors["ocr_mask"] = ocr_mask[i]
metadata = {
"architecture": "hunyuan_image",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (
vlm_embed[i].numpy(),
vlm_mask[i].numpy(),
byt5_embed[i].numpy(),
byt5_mask[i].numpy(),
ocr_mask[i].numpy(),
np.savez(
info.text_encoder_outputs_npz,
vlm_embed=vlm_embed_i,
vlm_mask=vlm_mask_i,
byt5_embed=byt5_embed_i,
byt5_mask=byt5_mask_i,
ocr_mask=ocr_mask_i,
)
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"
HUNYUAN_IMAGE_LATENTS_ST_SUFFIX = "_hi.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.HUNYUAN_IMAGE_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.cache_suffix
+ HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
)
def _get_architecture_name(self) -> str:
return "hunyuan_image"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -146,7 +146,6 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors"
def __init__(
self,
@@ -163,10 +162,19 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
return (
os.path.splitext(image_abs_path)[0]
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
"""
Args:
npz_path (str): Path to the npz file.
Returns:
bool: True if the npz file is expected to be cached.
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -175,26 +183,13 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state"):
return False
if "attention_mask" not in keys:
return False
if "input_ids" not in keys:
return False
else:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -203,22 +198,11 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
"""
Load outputs from a npz/safetensors file
Load outputs from a npz file
Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask
"""
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")).numpy()
attention_mask = f.get_tensor("attention_mask").numpy()
input_ids = f.get_tensor("input_ids").numpy()
return [hidden_state, input_ids, attention_mask]
data = np.load(npz_path)
hidden_state = data["hidden_state"]
attention_mask = data["attention_mask"]
@@ -233,6 +217,16 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
text_encoding_strategy: TextEncodingStrategy,
batch: List[train_util.ImageInfo],
) -> None:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of ImageInfo
Returns:
None
"""
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
@@ -258,75 +252,37 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
)
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch)
else:
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy()
input_ids_np = input_ids.cpu().numpy()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy() # (B, S)
input_ids = input_ids.cpu().numpy() # (B, S)
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids_np[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
)
else:
info.text_encoder_outputs = [
hidden_state_i,
input_ids_i,
attention_mask_i,
]
def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, attention_masks, batch):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state = hidden_state.cpu()
input_ids = input_ids.cpu()
attention_mask = attention_masks.cpu()
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
hs = hidden_state[i]
tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs
tensors["attention_mask"] = attention_mask[i]
tensors["input_ids"] = input_ids[i]
metadata = {
"architecture": "lumina",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
)
else:
info.text_encoder_outputs = [
hidden_state[i].numpy(),
input_ids[i].numpy(),
attention_mask[i].numpy(),
hidden_state_i,
input_ids_i,
attention_mask_i,
]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
@@ -335,7 +291,7 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
@property
def cache_suffix(self) -> str:
return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_LATENTS_NPZ_SUFFIX
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(
self, absolute_path: str, image_size: Tuple[int, int]
@@ -343,12 +299,9 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.cache_suffix
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
)
def _get_architecture_name(self) -> str:
return "lumina"
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],

View File

@@ -138,32 +138,24 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
SD_LATENTS_ST_SUFFIX = "_sd.safetensors"
SDXL_LATENTS_ST_SUFFIX = "_sdxl.safetensors"
def __init__(
self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
if self.cache_format == "safetensors":
return self.SD_LATENTS_ST_SUFFIX if self.sd else self.SDXL_LATENTS_ST_SUFFIX
else:
return self.SD_LATENTS_NPZ_SUFFIX if self.sd else self.SDXL_LATENTS_NPZ_SUFFIX
return self.suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
if self.cache_format != "safetensors":
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "sd" if self.sd else "sdxl"
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -255,7 +255,6 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_sd3_te.safetensors"
def __init__(
self,
@@ -271,8 +270,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -283,54 +281,27 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "lg_out"):
return False
if not _find_tensor_by_prefix(keys, "lg_pooled"):
return False
if "clip_l_attn_mask" not in keys or "clip_g_attn_mask" not in keys:
return False
if not _find_tensor_by_prefix(keys, "t5_out"):
return False
if "t5_attn_mask" not in keys:
return False
if "apply_lg_attn_mask" not in keys:
return False
apply_lg = f.get_tensor("apply_lg_attn_mask").item()
if bool(apply_lg) != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -338,20 +309,6 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
lg_out = f.get_tensor(_find_tensor_by_prefix(keys, "lg_out")).numpy()
lg_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "lg_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
l_attn_mask = f.get_tensor("clip_l_attn_mask").numpy()
g_attn_mask = f.get_tensor("clip_g_attn_mask").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
@@ -382,127 +339,65 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
enable_dropout=False,
)
l_attn_mask_tokens = tokens_and_masks[3]
g_attn_mask_tokens = tokens_and_masks[4]
t5_attn_mask_tokens = tokens_and_masks[5]
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(
lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
)
else:
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = l_attn_mask_tokens.cpu().numpy()
g_attn_mask = g_attn_mask_tokens.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(
self, lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
lg_out = lg_out.cpu()
t5_out = t5_out.cpu()
lg_pooled = lg_pooled.cpu()
l_attn_mask = l_attn_mask_tokens.cpu()
g_attn_mask = g_attn_mask_tokens.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
tensors[f"lg_out_{_dtype_to_str(lg_out_i.dtype)}"] = lg_out_i
tensors[f"t5_out_{_dtype_to_str(t5_out_i.dtype)}"] = t5_out_i
tensors[f"lg_pooled_{_dtype_to_str(lg_pooled_i.dtype)}"] = lg_pooled_i
tensors["clip_l_attn_mask"] = l_attn_mask[i]
tensors["clip_g_attn_mask"] = g_attn_mask[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_lg_attn_mask"] = torch.tensor(self.apply_lg_attn_mask, dtype=torch.bool)
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "sd3",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (
lg_out[i].numpy(),
t5_out[i].numpy(),
lg_pooled[i].numpy(),
l_attn_mask[i].numpy(),
g_attn_mask[i].numpy(),
t5_attn_mask[i].numpy(),
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
SD3_LATENTS_ST_SUFFIX = "_sd3.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.SD3_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_LATENTS_NPZ_SUFFIX
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.cache_suffix
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def _get_architecture_name(self) -> str:
return "sd3"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -221,7 +221,6 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_te_outputs.safetensors"
def __init__(
self,
@@ -234,8 +233,7 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -246,22 +244,9 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state1"):
return False
if not _find_tensor_by_prefix(keys, "hidden_state2"):
return False
if not _find_tensor_by_prefix(keys, "pool2"):
return False
else:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -269,17 +254,6 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state1 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state1")).numpy()
hidden_state2 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state2")).numpy()
pool2 = f.get_tensor(_find_tensor_by_prefix(keys, "pool2")).numpy()
return [hidden_state1, hidden_state2, pool2]
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
@@ -305,68 +279,28 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2]
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(hidden_state1, hidden_state2, pool2, infos)
else:
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
def _cache_batch_outputs_safetensors(self, hidden_state1, hidden_state2, pool2, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
tensors = {}
# merge existing file if partial
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
hs1 = hidden_state1[i]
hs2 = hidden_state2[i]
p2 = pool2[i]
tensors[f"hidden_state1_{_dtype_to_str(hs1.dtype)}"] = hs1
tensors[f"hidden_state2_{_dtype_to_str(hs2.dtype)}"] = hs2
tensors[f"pool2_{_dtype_to_str(p2.dtype)}"] = p2
metadata = {
"architecture": "sdxl",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [
hidden_state1[i].numpy(),
hidden_state2[i].numpy(),
pool2[i].numpy(),
]
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]

View File

@@ -4472,10 +4472,7 @@ def verify_training_args(args: argparse.Namespace):
Verify training arguments. Also reflect highvram option to global variable
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
"""
from library.strategy_base import set_cache_format
enable_high_vram(args)
set_cache_format(args.cache_format)
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -4641,14 +4638,6 @@ def add_dataset_arguments(
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument(
"--cache_format",
type=str,
default="npz",
choices=["npz", "safetensors"],
help="format for latent and text encoder output caches (default: npz). safetensors saves in native dtype (e.g. bf16) for smaller files and faster I/O"
" / latentおよびtext encoder出力キャッシュの保存形式デフォルト: npz。safetensorsはネイティブdtype例: bf16で保存し、ファイルサイズ削減と高速化が可能",
)
parser.add_argument(
"--enable_bucket",
action="store_true",

View File

@@ -69,8 +69,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
strategy_base.set_cache_format(args.cache_format)
if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
else:

View File

@@ -156,8 +156,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
text_encoder.eval()
# build text encoder outputs caching strategy
strategy_base.set_cache_format(args.cache_format)
if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions