Files
Kohya-ss-sd-scripts/library/strategy_lumina.py
Kohya S a437949d47 feat: Add support for Safetensors format in caching strategies (WIP)
- Introduced Safetensors output format for various caching strategies including Hunyuan, Lumina, SD, SDXL, and SD3.
- Updated methods to handle loading and saving of tensors in Safetensors format.
- Enhanced output validation to check for required tensors in both NPZ and Safetensors formats.
- Modified dataset argument parser to include `--cache_format` option for selecting between NPZ and Safetensors formats.
- Updated caching logic to accommodate partial loading and merging of existing Safetensors files.
2026-03-22 21:15:12 +09:00

423 lines
15 KiB
Python

import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
from library import train_util
from library.strategy_base import (
LatentsCachingStrategy,
TokenizeStrategy,
TextEncodingStrategy,
TextEncoderOutputsCachingStrategy,
)
import numpy as np
from library.utils import setup_logging
# Project-level logging setup from library.utils; runs at import time, before this module's logger is created.
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Hugging Face model id used to load the Gemma-2 tokenizer in LuminaTokenizeStrategy.
GEMMA_ID = "google/gemma-2-2b"
class LuminaTokenizeStrategy(TokenizeStrategy):
    """Tokenize captions for Lumina with the Gemma-2 tokenizer (``GEMMA_ID``).

    For non-negative prompts, ``self.system_prompt`` (the user system prompt
    followed by the ``<Prompt Start>`` marker) is prepended to every caption.
    """

    def __init__(
        self,
        system_prompt: Optional[str],
        max_length: Optional[int],
        tokenizer_cache_dir: Optional[str] = None,
    ) -> None:
        """
        Args:
            system_prompt (Optional[str]): Text prepended to positive prompts.
                ``None`` or ``""`` disables the prefix entirely (including the
                ``<Prompt Start>`` marker).
            max_length (Optional[int]): Maximum token length; 256 when ``None``.
            tokenizer_cache_dir (Optional[str]): ``cache_dir`` forwarded to
                ``AutoTokenizer.from_pretrained``.
        """
        self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
            GEMMA_ID, cache_dir=tokenizer_cache_dir
        )
        # Padding tokens are appended after the text tokens.
        self.tokenizer.padding_side = "right"

        if system_prompt is None:
            system_prompt = ""
        system_prompt_special_token = "<Prompt Start>"
        # An empty system prompt yields no prefix at all, not even the marker.
        system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
        self.system_prompt = system_prompt

        self.max_length = 256 if max_length is None else max_length

    def tokenize(
        self, text: Union[str, List[str]], is_negative: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize a single caption or a batch of captions.

        Args:
            text (Union[str, List[str]]): Text to tokenize.
            is_negative (bool): When True, the system prompt is NOT prepended.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                token input ids, attention_masks (padded with padding="max_length",
                truncated to ``self.max_length``, pad_to_multiple_of=8).
        """
        text = [text] if isinstance(text, str) else text
        # In training, we always add system prompt (is_negative=False)
        if not is_negative:
            # Add system prompt to the beginning of each text
            text = [self.system_prompt + t for t in text]
        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            pad_to_multiple_of=8,
        )
        return (encodings.input_ids, encodings.attention_mask)

    def tokenize_with_weights(
        self, text: Union[str, List[str]]
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Tokenize like ``tokenize`` (system prompt always added) and return uniform weights.

        Args:
            text (Union[str, List[str]]): Text to tokenize.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                token input ids, attention_masks, and one all-ones weight tensor
                per row of the input ids.
        """
        # Gemma doesn't support weighted prompts, return uniform weights
        tokens, attention_masks = self.tokenize(text)
        weights = [torch.ones_like(t) for t in tokens]
        return tokens, attention_masks, weights
class LuminaTextEncodingStrategy(TextEncodingStrategy):
    """Run the Gemma-2 text encoder on tokenized captions for Lumina conditioning."""

    def __init__(self) -> None:
        super().__init__()

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode tokens with the first model in ``models``.

        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy (unused here).
            models (List[Any]): Text encoders; ``models[0]`` must be a Gemma2Model
                or a torch.compile/dynamo wrapper exposing one via ``_orig_mod``.
            tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                hidden_states (second-to-last layer), input_ids, attention_masks.
                input_ids and attention_masks are returned exactly as passed in
                (they are not moved to the encoder's device).

        Raises:
            AssertionError: if ``models[0]`` is not (a wrapper of) Gemma2Model.
        """
        text_encoder = models[0]
        # Check model or torch dynamo OptimizedModule. getattr avoids an
        # AttributeError (and keeps the assertion message) for objects that
        # have no ``_orig_mod`` attribute.
        assert isinstance(text_encoder, Gemma2Model) or isinstance(
            getattr(text_encoder, "_orig_mod", None), Gemma2Model
        ), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}"

        input_ids, attention_masks = tokens
        outputs = text_encoder(
            input_ids=input_ids.to(text_encoder.device),
            attention_mask=attention_masks.to(text_encoder.device),
            output_hidden_states=True,
            return_dict=True,
        )
        # The second-to-last hidden layer is used, not the final layer output.
        return outputs.hidden_states[-2], input_ids, attention_masks

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: Tuple[torch.Tensor, torch.Tensor],
        weights: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode tokens; weights are ignored and the result equals ``encode_tokens``.

        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
            weights (List[torch.Tensor]): Currently unused

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                hidden_states, input_ids, attention_masks
        """
        # For simplicity, use uniform weighting
        return self.encode_tokens(tokenize_strategy, models, tokens)
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Cache Gemma-2 outputs (hidden_state, input_ids, attention_mask) for Lumina.

    Outputs go either to disk (``.npz`` or ``.safetensors`` depending on
    ``self.cache_format``) or into ``ImageInfo.text_encoder_outputs`` in memory.
    ``cache_format``, ``cache_to_disk``, ``is_partial``, ``is_weighted`` and
    ``skip_disk_cache_validity_check`` are read from ``self``.
    NOTE(review): presumably initialized by the base class — confirm in strategy_base.
    """

    LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
    LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
    ) -> None:
        super().__init__(
            cache_to_disk,
            batch_size,
            skip_disk_cache_validity_check,
            is_partial,
        )

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert a CPU tensor to numpy, upcasting float dtypes numpy lacks (e.g. bfloat16) to float32."""
        if tensor.is_floating_point() and tensor.dtype not in (torch.float16, torch.float32, torch.float64):
            # Tensor.numpy() raises TypeError for bfloat16 / float8 tensors.
            tensor = tensor.float()
        return tensor.numpy()

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the cache file path for an image.

        Despite the method name, the suffix is ``_lumina_te.safetensors`` when
        ``self.cache_format == "safetensors"`` and ``_lumina_te.npz`` otherwise.
        """
        suffix = (
            self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX
            if self.cache_format == "safetensors"
            else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
        )
        return os.path.splitext(image_abs_path)[0] + suffix

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Return True if a usable disk cache exists at ``npz_path``.

        Returns False when disk caching is disabled or the file is missing.
        With ``skip_disk_cache_validity_check`` an existing file is trusted
        without opening it; otherwise the file must hold a hidden state
        (a key starting with ``hidden_state`` for safetensors, exactly
        ``hidden_state`` for npz) plus ``attention_mask`` and ``input_ids``.

        Raises:
            Exception: errors while reading the file are logged and re-raised.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            if npz_path.endswith(".safetensors"):
                from library.safetensors_utils import MemoryEfficientSafeOpen
                from library.strategy_base import _find_tensor_by_prefix

                with MemoryEfficientSafeOpen(npz_path) as f:
                    keys = f.keys()
                    # The hidden state key carries a dtype suffix, so match by prefix.
                    if not _find_tensor_by_prefix(keys, "hidden_state"):
                        return False
                    if "attention_mask" not in keys:
                        return False
                    if "input_ids" not in keys:
                        return False
            else:
                # Context manager closes the NpzFile's underlying file handle.
                with np.load(npz_path) as npz:
                    for key in ("hidden_state", "attention_mask", "input_ids"):
                        if key not in npz:
                            return False
        except Exception:
            logger.error(f"Error loading file: {npz_path}")
            raise
        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """
        Load outputs from a npz/safetensors file.

        Args:
            npz_path (str): Cache path; a ``.safetensors`` extension selects the
                safetensors reader, anything else is read with ``np.load``.

        Returns:
            List[np.ndarray]: hidden_state, input_ids, attention_mask.
                A safetensors hidden state stored in a dtype numpy cannot
                represent (e.g. bfloat16) is upcast to float32.
        """
        if npz_path.endswith(".safetensors"):
            from library.safetensors_utils import MemoryEfficientSafeOpen
            from library.strategy_base import _find_tensor_by_prefix

            with MemoryEfficientSafeOpen(npz_path) as f:
                keys = f.keys()
                hidden_state = self._tensor_to_numpy(f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")))
                attention_mask = f.get_tensor("attention_mask").numpy()
                input_ids = f.get_tensor("input_ids").numpy()
            return [hidden_state, input_ids, attention_mask]

        # Arrays are read fully inside the with-block; the file is then closed.
        with np.load(npz_path) as data:
            hidden_state = data["hidden_state"]
            attention_mask = data["attention_mask"]
            input_ids = data["input_ids"]
        return [hidden_state, input_ids, attention_mask]

    @torch.no_grad()
    def cache_batch_outputs(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: List[train_util.ImageInfo],
    ) -> None:
        """Encode the captions of ``batch`` and cache the outputs per image.

        The npz path stores hidden states as float32 (on disk or in memory).
        The safetensors path is delegated to ``_cache_batch_outputs_safetensors``.
        """
        assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
        assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)

        captions = [info.caption for info in batch]

        if self.is_weighted:
            tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights(captions)
            hidden_state, input_ids, attention_masks = text_encoding_strategy.encode_tokens_with_weights(
                tokenize_strategy,
                models,
                (tokens, attention_masks),
                weights_list,
            )
        else:
            tokens = tokenize_strategy.tokenize(captions)
            hidden_state, input_ids, attention_masks = text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens
            )

        if self.cache_format == "safetensors":
            self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch)
            return

        # numpy has no bfloat16; store float32 in the npz format.
        if hidden_state.dtype != torch.float32:
            hidden_state = hidden_state.float()

        hidden_state = hidden_state.cpu().numpy()
        attention_mask = attention_masks.cpu().numpy()
        input_ids_np = input_ids.cpu().numpy()

        for i, info in enumerate(batch):
            hidden_state_i = hidden_state[i]
            attention_mask_i = attention_mask[i]
            input_ids_i = input_ids_np[i]

            if self.cache_to_disk:
                assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_state=hidden_state_i,
                    attention_mask=attention_mask_i,
                    input_ids=input_ids_i,
                )
            else:
                info.text_encoder_outputs = [
                    hidden_state_i,
                    input_ids_i,
                    attention_mask_i,
                ]

    def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, attention_masks, batch):
        """Cache a batch of outputs in safetensors files, or in memory.

        On disk, the hidden state keeps its original dtype under the key
        ``hidden_state_<dtype>``. In memory, tensors become numpy arrays;
        hidden-state dtypes numpy cannot hold (e.g. bfloat16) are upcast to float32.
        """
        from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
        from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION

        hidden_state = hidden_state.cpu()
        input_ids = input_ids.cpu()
        attention_mask = attention_masks.cpu()

        for i, info in enumerate(batch):
            if self.cache_to_disk:
                assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
                tensors = {}
                # Partial caching: start from the existing file's tensors and overwrite ours.
                # NOTE(review): if the hidden-state dtype changed since the file was written,
                # the old "hidden_state_<dtype>" key survives next to the new one, and the
                # prefix lookup in load_outputs_npz may pick either — confirm intended.
                if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
                    with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
                        for key in f.keys():
                            tensors[key] = f.get_tensor(key)

                hs = hidden_state[i]
                tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs
                tensors["attention_mask"] = attention_mask[i]
                tensors["input_ids"] = input_ids[i]

                metadata = {
                    "architecture": "lumina",
                    "caption1": info.caption,
                    "format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
                }
                mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
            else:
                info.text_encoder_outputs = [
                    self._tensor_to_numpy(hidden_state[i]),
                    input_ids[i].numpy(),
                    attention_mask[i].numpy(),
                ]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
    """Cache VAE latents for Lumina training with multi-resolution support.

    Cache files are named ``<image stem>_{image_size[0]:04d}x{image_size[1]:04d}``
    followed by ``_lumina.safetensors`` when ``self.cache_format == "safetensors"``
    and ``_lumina.npz`` otherwise.
    """

    LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
    LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors"

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        """File suffix for latent caches, chosen by ``self.cache_format``."""
        return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(
        self, absolute_path: str, image_size: Tuple[int, int]
    ) -> str:
        """Return the latent cache path for an image at ``image_size`` (npz or safetensors)."""
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + self.cache_suffix
        )

    def _get_architecture_name(self) -> str:
        return "lumina"

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool,
    ) -> bool:
        """
        Check the on-disk latent cache via the base-class default implementation.

        Args:
            bucket_reso (Tuple[int, int]): The resolution of the bucket.
            npz_path (str): Path to the cache file (npz or safetensors).
            flip_aug (bool): Whether flip augmentation is enabled.
            alpha_mask (bool): Whether alpha masks are in use.

        Returns:
            bool: result of ``_default_is_disk_cached_latents_expected``
                (presumably True when the cache can be reused — confirm in strategy_base).
        """
        # NOTE(review): 8 is presumably the VAE's spatial downscale factor — confirm.
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[
        Optional[np.ndarray],
        Optional[List[int]],
        Optional[List[int]],
        Optional[np.ndarray],
        Optional[np.ndarray],
    ]:
        """
        Load cached latents via the base-class default loader (multi-resolution).

        Args:
            npz_path (str): Path to the cache file (npz or safetensors).
            bucket_reso (Tuple[int, int]): The resolution of the bucket.

        Returns:
            The 5-tuple produced by ``_default_load_latents_from_disk``
            (array, int list, int list, array, array per the annotation).
            NOTE(review): presumably (latents, original_size, crop_ltrb,
            flipped_latents, alpha_mask) — confirm against strategy_base.
        """
        return self._default_load_latents_from_disk(
            8, npz_path, bucket_reso
        )  # support multi-resolution

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(
        self,
        model,
        batch: List,
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool,
    ):
        """Encode ``batch`` with the VAE ``model`` and cache the latents.

        Frees device memory afterwards unless ``train_util.HIGH_VRAM`` is set.
        """

        def encode_by_vae(img_tensor):
            # Latents are moved to CPU before being cached.
            return model.encode(img_tensor).to("cpu")

        vae_device = model.device
        vae_dtype = model.dtype

        self._default_cache_batch_latents(
            encode_by_vae,
            vae_device,
            vae_dtype,
            batch,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True,
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(model.device)