Files
Kohya-ss-sd-scripts/library/strategy_anima.py
Kohya S a437949d47 feat: Add support for Safetensors format in caching strategies (WIP)
- Introduced Safetensors output format for various caching strategies including Hunyuan, Lumina, SD, SDXL, and SD3.
- Updated methods to handle loading and saving of tensors in Safetensors format.
- Enhanced output validation to check for required tensors in both NPZ and Safetensors formats.
- Modified dataset argument parser to include `--cache_format` option for selecting between NPZ and Safetensors formats.
- Updated caching logic to accommodate partial loading and merging of existing Safetensors files.
2026-03-22 21:15:12 +09:00

382 lines
16 KiB
Python

# Anima Strategy Classes
import os
import random
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library import qwen_image_autoencoder_kl
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class AnimaTokenizeStrategy(TokenizeStrategy):
"""Tokenize strategy for Anima: dual tokenization with Qwen3 + T5.
Qwen3 tokens are used for the text encoder.
T5 tokens are used as target input IDs for the LLM Adapter (NOT encoded by T5).
Can be initialized with either pre-loaded tokenizer objects or paths to load from.
"""
def __init__(
self,
qwen3_tokenizer=None,
t5_tokenizer=None,
qwen3_max_length: int = 512,
t5_max_length: int = 512,
qwen3_path: Optional[str] = None,
t5_tokenizer_path: Optional[str] = None,
) -> None:
# Load tokenizers from paths if not provided directly
if qwen3_tokenizer is None:
if qwen3_path is None:
raise ValueError("Either qwen3_tokenizer or qwen3_path must be provided")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(qwen3_path)
if t5_tokenizer is None:
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
class AnimaTextEncodingStrategy(TextEncodingStrategy):
"""Text encoding strategy for Anima.
Encodes Qwen3 tokens through the Qwen3 text encoder to get hidden states.
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
def __init__(self) -> None:
super().__init__()
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
Args:
models: [qwen3_text_encoder]
tokens: [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
encoder_device = qwen3_text_encoder.device
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
prompt_embeds[~qwen3_attn_mask.bool()] = 0
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
prompt_embeds: torch.Tensor,
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
Called during training when using cached outputs.
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]):
if random.random() < caption_dropout_rates[i].item():
# Use pre-cached unconditional embeddings
prompt_embeds[i] = 0
if attn_mask is not None:
attn_mask[i] = 0
if t5_input_ids is not None:
t5_input_ids[i, 0] = 1 # Set to </s> token ID
t5_input_ids[i, 1:] = 0
if t5_attn_mask is not None:
t5_attn_mask[i, 0] = 1
t5_attn_mask[i, 1:] = 0
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
"""Caching strategy for Anima text encoder outputs.
Caches: prompt_embeds (float), attn_mask (int), t5_input_ids (int), t5_attn_mask (int)
"""
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "prompt_embeds"):
return False
if "attn_mask" not in keys:
return False
if "t5_input_ids" not in keys:
return False
if "t5_attn_mask" not in keys:
return False
if "caption_dropout_rate" not in keys:
return False
else:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy()
attn_mask = f.get_tensor("attn_mask").numpy()
t5_input_ids = f.get_tensor("t5_input_ids").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy()
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
data = np.load(npz_path)
prompt_embeds = data["prompt_embeds"]
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
infos: List,
):
anima_text_encoding_strategy: AnimaTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos)
else:
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
prompt_embeds = prompt_embeds.cpu()
attn_mask = attn_mask.cpu()
t5_input_ids = t5_input_ids.cpu().to(torch.int32)
t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)
for i, info in enumerate(infos):
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
pe = prompt_embeds[i]
tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
tensors["attn_mask"] = attn_mask[i]
tensors["t5_input_ids"] = t5_input_ids[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["caption_dropout_rate"] = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
metadata = {
"architecture": "anima",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
info.text_encoder_outputs = (
prompt_embeds[i].numpy(),
attn_mask[i].numpy(),
t5_input_ids[i].numpy(),
t5_attn_mask[i].numpy(),
caption_dropout_rate,
)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
"""Latent caching strategy for Anima using WanVAE.
WanVAE produces 16-channel latents with spatial downscale 8x.
Latent shape for images: (B, 16, 1, H/8, W/8)
"""
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
ANIMA_LATENTS_ST_SUFFIX = "_anima.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.ANIMA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "anima"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
"""Cache batch of latents using Qwen Image VAE.
vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
The encoding function handles the mean/std normalization.
"""
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
vae_device = vae.device
vae_dtype = vae.dtype
def encode_by_vae(img_tensor):
"""Encode image tensor to latents.
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
"""
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
return latents.to("cpu")
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae_device)