mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Text Encoder cache (WIP)
This commit is contained in:
@@ -5,9 +5,6 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from library import flux_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -15,6 +12,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_utils, train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
||||
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
|
||||
KEYS = ["l_pooled", "t5_out", "txt_ids"]
|
||||
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: int,
|
||||
masked: bool,
|
||||
is_partial: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
super().__init__(
|
||||
FluxLatentsCachingStrategy.ARCHITECTURE,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
max_token_length,
|
||||
masked,
|
||||
is_partial,
|
||||
)
|
||||
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
):
|
||||
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "l_pooled" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "txt_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
l_pooled = data["l_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
txt_ids = data["txt_ids"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
)
|
||||
if self.masked:
|
||||
t5_attn_mask = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
)[0]
|
||||
else:
|
||||
t5_attn_mask = None
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
if not self.warn_fp8_weights:
|
||||
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
|
||||
@@ -154,44 +145,38 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
self.warn_fp8_weights = True
|
||||
|
||||
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
|
||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
|
||||
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
l_pooled = l_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
if txt_ids.dtype == torch.bfloat16:
|
||||
txt_ids = txt_ids.float()
|
||||
l_pooled = l_pooled.cpu()
|
||||
t5_out = t5_out.cpu()
|
||||
txt_ids = txt_ids.cpu()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu()
|
||||
|
||||
l_pooled = l_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
l_pooled_i = l_pooled[i]
|
||||
t5_out_i = t5_out[i]
|
||||
txt_ids_i = txt_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_t5_attn_mask_i = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
l_pooled=l_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
txt_ids=txt_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask_i,
|
||||
)
|
||||
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
|
||||
if self.masked:
|
||||
outputs += [t5_attn_mask_i]
|
||||
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
|
||||
|
||||
|
||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -215,8 +200,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
Reference in New Issue
Block a user