Kohya-ss-sd-scripts/library/strategy_base.py
Commit a437949d47 (Kohya S): feat: Add support for Safetensors format in caching strategies (WIP)
- Introduced Safetensors output format for various caching strategies including Hunyuan, Lumina, SD, SDXL, and SD3.
- Updated methods to handle loading and saving of tensors in Safetensors format.
- Enhanced output validation to check for required tensors in both NPZ and Safetensors formats.
- Modified dataset argument parser to include `--cache_format` option for selecting between NPZ and Safetensors formats.
- Updated caching logic to accommodate partial loading and merging of existing Safetensors files.
2026-03-22 21:15:12 +09:00
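
A minimal sketch of how the new option might be consumed by a training script (the exact argument wiring in sd-scripts may differ; only `set_cache_format` / `get_cache_format`, defined in the file below, are taken from the source):

```python
import argparse
from library import strategy_base

parser = argparse.ArgumentParser()
# assumed flag name, matching the commit description; the default stays "npz"
parser.add_argument("--cache_format", choices=["npz", "safetensors"], default="npz")
args = parser.parse_args()

# all caching strategies read this module-level setting via their cache_format property
strategy_base.set_cache_format(args.cache_format)
assert strategy_base.get_cache_format() == args.cache_format
```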

# base class for platform strategies. this file defines the interface for strategies
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
LATENTS_CACHE_FORMAT_VERSION = "1.0.1"
TE_OUTPUTS_CACHE_FORMAT_VERSION = "1.0.1"
# global cache format setting: "npz" or "safetensors"
_cache_format: str = "npz"
def set_cache_format(cache_format: str) -> None:
global _cache_format
_cache_format = cache_format
def get_cache_format() -> str:
return _cache_format
_TORCH_DTYPE_TO_STR = {
torch.float64: "float64",
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
torch.int16: "int16",
torch.int8: "int8",
torch.uint8: "uint8",
torch.bool: "bool",
}
_FLOAT_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16}
def _dtype_to_str(dtype: torch.dtype) -> str:
return _TORCH_DTYPE_TO_STR.get(dtype, str(dtype).replace("torch.", ""))
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
"""Find a tensor key that starts with the given prefix. Returns the first match or None."""
for key in tensors_keys:
if key.startswith(prefix) or key == prefix:
return key
return None
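# Illustrative note: the safetensors cache keys written further below carry a resolution tag and a
# dtype suffix, e.g. "latents_1x64x64_float16". _find_tensor_by_prefix lets callers look a tensor up
# without knowing the stored dtype, for example (hypothetical key list):
#   _find_tensor_by_prefix(["latents_1x64x64_float16", "crop_ltrb_1x64x64_int32"], "latents_1x64x64")
#   -> "latents_1x64x64_float16"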
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
_re_attention = re.compile(
r"""\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy
def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
"""
raise NotImplementedError
def _get_weighted_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
max_length includes starting and ending tokens.
"""
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in TokenizeStrategy._re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(text: str, max_length: int):
r"""
Tokenize a single prompt and return its tokens together with a weight for each token. max_length does not include the starting and ending tokens.
No padding, starting or ending tokens are included in the output.
"""
truncated = False
texts_and_weights = parse_prompt_attention(text)
tokens = []
weights = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
tokens += token
# copy the weight by length of token
weights += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(tokens) > max_length:
truncated = True
break
# truncate
if len(tokens) > max_length:
truncated = True
tokens = tokens[:max_length]
weights = weights[:max_length]
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
return tokens, weights
if max_length is None:
max_length = tokenizer.model_max_length
tokens, weights = get_prompts_with_weights(text, max_length - 2)
tokens, weights = pad_tokens_and_weights(
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
)
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
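# Illustrative example (assuming a standard CLIP tokenizer with model_max_length=77):
#   ids, weights = self._get_weighted_input_ids(tokenizer, "a (important:1.3) word")
#   ids.shape     -> torch.Size([1, 77])   # <BOS> + content tokens + <EOS> + padding
#   weights.shape -> torch.Size([1, 77])   # 1.0 everywhere except the "important" tokens (1.3)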
def _get_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
"""
if max_length is None:
max_length = tokenizer.model_max_length - 2
if weighted:
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
else:
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
# When the token count is 77 or more, the ids look like "<BOS> .... <EOS> <EOS> <EOS>" with a total length of e.g. 227, so convert them into chunks of "<BOS> ... <EOS>"
# AUTOMATIC1111's implementation seems to also split on "," etc., but keep it simple here for now
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
# When the token count is 77 or more, the ids look like "<BOS> .... <EOS> <PAD> <PAD>..." with a total length of e.g. 227, so convert them into chunks of "<BOS> ... <EOS> <PAD> <PAD> ..."
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# If the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
# If the chunk ends with "x <PAD/EOS>", replace the last token with <EOS> (if it was already "x <EOS>", the result is unchanged)
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# If the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
if weighted:
weights = weights.squeeze(0)
new_weights = torch.ones(input_ids.shape)
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
b = i // (tokenizer.model_max_length - 2)
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
weights = new_weights
if weighted:
return input_ids, weights
return input_ids
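# Illustrative example: with a CLIP tokenizer (model_max_length=77) and max_length=227
# (225 content tokens plus BOS/EOS), the loop above iterates over i = 1, 76, 151 and the
# result is stacked into a (3, 77) tensor, each chunk rebuilt as <BOS> + 75 tokens + <EOS>/<PAD>.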
class TextEncodingStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
return cls._strategy
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:param weights: list of weight tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._is_partial = is_partial
self._is_weighted = is_weighted
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def is_partial(self):
return self._is_partial
@property
def is_weighted(self):
return self._is_weighted
@property
def cache_format(self) -> str:
return get_cache_format()
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
):
raise NotImplementedError
class LatentsCachingStrategy:
# TODO move common utility functions, such as npz handling, into this class
_strategy = None # strategy instance: actual strategy class
_warned_fallback_to_old_npz = False # to avoid spamming logs about fallback
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def cache_format(self) -> str:
return get_cache_format()
@property
def cache_suffix(self):
raise NotImplementedError
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
return int(w), int(h)
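# Illustrative example: for a cache path such as ".../img_0512x0768_xxx.npz" (the exact naming is
# defined by each subclass's get_latents_npz_path; this file name is hypothetical), the
# second-to-last "_"-separated chunk "0512x0768" is parsed back into (width, height) = (512, 768).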
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
) -> bool:
"""
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz/safetensors file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
Returns:
bool
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
if npz_path.endswith(".safetensors"):
return self._is_disk_cached_latents_expected_safetensors(
latents_stride, bucket_reso, npz_path, flip_aug, apply_alpha_mask, multi_resolution
)
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
npz = np.load(npz_path)
# In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility)
# In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files).
if "latents" + key_reso_suffix not in npz and "latents" not in npz:
return False
if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz):
return False
if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz):
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
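# Illustrative example: with bucket_reso=(512, 256) (W, H), latents_stride=8 and
# multi_resolution=True, expected_latents_size is (32, 64) (H, W), so the npz is checked for
# "latents_32x64" (and "latents_flipped_32x64" / "alpha_mask_32x64" when required),
# falling back to the un-suffixed keys of old SD/SDXL caches.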
def _is_disk_cached_latents_expected_safetensors(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
st_path: str,
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
) -> bool:
from library.safetensors_utils import MemoryEfficientSafeOpen
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # (H, W)
reso_tag = f"1x{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "1x"
try:
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_prefix = f"latents_{reso_tag}"
if not any(k.startswith(latents_prefix) for k in keys):
return False
if flip_aug:
flipped_prefix = f"latents_flipped_{reso_tag}"
if not any(k.startswith(flipped_prefix) for k in keys):
return False
if apply_alpha_mask:
mask_prefix = f"alpha_mask_{reso_tag}"
if not any(k.startswith(mask_prefix) for k in keys):
return False
except Exception as e:
logger.error(f"Error loading file: {st_path}")
raise e
return True
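# Illustrative example: the safetensors variant of the same check uses a "1x"-prefixed tag, so
# bucket_reso=(512, 256) and stride 8 give reso_tag "1x32x64" and the required key prefixes are
# "latents_1x32x64", "latents_flipped_1x32x64" and "alpha_mask_1x32x64"
# (latents keys additionally end with a dtype suffix, hence the prefix match).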
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae: Callable,
vae_device: torch.device,
vae_dtype: torch.dtype,
image_infos: List,
flip_aug: bool,
apply_alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
Args:
encode_by_vae: function to encode images by VAE
vae_device: device to use for VAE
vae_dtype: dtype to use for VAE
image_infos: list of ImageInfo
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
random_crop: whether to randomly crop images
multi_resolution: whether to use multi-resolution latents
Returns:
None
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, apply_alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
with torch.no_grad():
latents_tensors = encode_by_vae(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = encode_by_vae(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[-2:] # H, W (supports both 4D and 5D latents)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
For single resolution architectures (currently no architecture is single resolution specific). Kept for reference.
Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz/safetensors file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if npz_path.endswith(".safetensors"):
return self._load_latents_from_disk_safetensors(latents_stride, npz_path, bucket_reso)
if latents_stride is None:
key_reso_suffix = ""
else:
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
# raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# Fallback to old npz without resolution suffix
if "latents" not in npz:
raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})")
if not self._warned_fallback_to_old_npz:
logger.warning(
f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version."
)
self._warned_fallback_to_old_npz = True
key_reso_suffix = ""
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def _load_latents_from_disk_safetensors(
self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
from library.safetensors_utils import MemoryEfficientSafeOpen
if latents_stride is None:
reso_tag = "1x"
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_key = _find_tensor_by_prefix(keys, f"latents_{reso_tag}")
if latents_key is None:
raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
latents = f.get_tensor(latents_key).numpy()
original_size_key = _find_tensor_by_prefix(keys, f"original_size_{reso_tag}")
original_size = f.get_tensor(original_size_key).numpy().tolist() if original_size_key else [0, 0]
crop_ltrb_key = _find_tensor_by_prefix(keys, f"crop_ltrb_{reso_tag}")
crop_ltrb = f.get_tensor(crop_ltrb_key).numpy().tolist() if crop_ltrb_key else [0, 0, 0, 0]
flipped_key = _find_tensor_by_prefix(keys, f"latents_flipped_{reso_tag}")
flipped_latents = f.get_tensor(flipped_key).numpy() if flipped_key else None
alpha_mask_key = _find_tensor_by_prefix(keys, f"alpha_mask_{reso_tag}")
alpha_mask = f.get_tensor(alpha_mask_key).numpy() if alpha_mask_key else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
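# Illustrative example: loading a safetensors cache with stride 8 and bucket_reso=(512, 256)
# resolves keys through the "latents_1x32x64" prefix and converts them to numpy arrays, mirroring
# the npz path above; missing optional entries fall back to (0, 0) for original_size,
# (0, 0, 0, 0) for crop_ltrb, and None for flipped latents / alpha mask.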
def save_latents_to_disk(
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
"""
Args:
npz_path (str): Path to the npz/safetensors file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix (e.g. "_32x64" for multi-resolution npz)
Returns:
None
"""
if npz_path.endswith(".safetensors"):
self._save_latents_to_disk_safetensors(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor, alpha_mask, key_reso_suffix
)
return
kwargs = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
# float() is needed because npz doesn't support bfloat16
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
if flipped_latents_tensor is not None:
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)
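# Note: because any existing npz is reloaded above and re-saved with the new keys added, caching
# the same image at several resolutions accumulates entries such as "latents_32x64" and
# "latents_64x128" in a single file.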
def _save_latents_to_disk_safetensors(
self,
st_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
latents_tensor = latents_tensor.cpu()
latents_size = latents_tensor.shape[-2:] # H, W
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
dtype_str = _dtype_to_str(latents_tensor.dtype)
# NaN check and zero replacement
if torch.isnan(latents_tensor).any():
latents_tensor = torch.nan_to_num(latents_tensor, nan=0.0)
tensors: Dict[str, torch.Tensor] = {}
# load existing file and merge (for multi-resolution)
if os.path.exists(st_path):
with MemoryEfficientSafeOpen(st_path) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
tensors[f"latents_{reso_tag}_{dtype_str}"] = latents_tensor
tensors[f"original_size_{reso_tag}_int32"] = torch.tensor(original_size, dtype=torch.int32)
tensors[f"crop_ltrb_{reso_tag}_int32"] = torch.tensor(crop_ltrb, dtype=torch.int32)
if flipped_latents_tensor is not None:
flipped_latents_tensor = flipped_latents_tensor.cpu()
if torch.isnan(flipped_latents_tensor).any():
flipped_latents_tensor = torch.nan_to_num(flipped_latents_tensor, nan=0.0)
tensors[f"latents_flipped_{reso_tag}_{dtype_str}"] = flipped_latents_tensor
if alpha_mask is not None:
alpha_mask_tensor = alpha_mask.cpu() if isinstance(alpha_mask, torch.Tensor) else torch.tensor(alpha_mask)
tensors[f"alpha_mask_{reso_tag}"] = alpha_mask_tensor
metadata = {
"architecture": self._get_architecture_name(),
"width": str(latents_size[1]),
"height": str(latents_size[0]),
"format_version": LATENTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, st_path, metadata=metadata)
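# Illustrative example: for a 4x64x64 float16 latent this writes keys such as
# "latents_1x64x64_float16", "original_size_1x64x64_int32" and "crop_ltrb_1x64x64_int32",
# plus optional "latents_flipped_1x64x64_float16" / "alpha_mask_1x64x64" entries, with metadata
# {"architecture": <subclass name>, "width": "64", "height": "64",
#  "format_version": LATENTS_CACHE_FORMAT_VERSION}.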
def _get_architecture_name(self) -> str:
"""Override in subclasses to return the architecture name for safetensors metadata."""
return "unknown"