Text Encoder cache (WIP)

This commit is contained in:
Kohya S
2024-11-27 12:57:04 +09:00
parent bdac55ebbc
commit 3677094256
15 changed files with 628 additions and 471 deletions

View File

@@ -5,9 +5,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -15,6 +12,8 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_utils, train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
KEYS = ["l_pooled", "t5_out", "txt_ids"]
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
# Constructor: forwards the caching options to the base-class initializer and
# clears the fp8-warning flag.
# NOTE(review): this span contains two `super().__init__` calls with different
# argument lists, plus both `apply_t5_attn_mask` and `max_token_length`/`masked`
# parameters. It looks like the pre-change and post-change lines of a diff
# rendered together — confirm which set is current before relying on either.
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
# First base-class call: passes only the four generic caching options.
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
# Stored on the instance; the npz validity check and the npz writer visible in this file read it.
self.apply_t5_attn_mask = apply_t5_attn_mask
# Second base-class call: additionally passes the architecture identifier taken
# from FluxLatentsCachingStrategy, the max token length and the masked flag.
super().__init__(
FluxLatentsCachingStrategy.ARCHITECTURE,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked,
is_partial,
)
# Starts False; cache_batch_outputs tests this flag before its fp8 dtype check and
# later sets it to True (presumably so the fp8 warning is emitted only once — confirm).
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
    """Build the text-encoder cache path that belongs to an image.

    The image's file extension is dropped and the Flux text-encoder npz
    suffix is appended to what remains.

    Args:
        image_abs_path: Path of the image file.

    Returns:
        The image path without its extension, followed by
        ``FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX``.
    """
    stem, _extension = os.path.splitext(image_abs_path)
    suffix = FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
    return stem + suffix
def is_disk_cached_outputs_expected(
    self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
):
    """Check whether the cache file at ``cache_path`` holds the expected outputs.

    Assembles the list of required keys — ``KEYS``, plus ``KEYS_MASKED`` when
    ``self.masked`` is set — and delegates the actual check to
    ``self._default_is_disk_cached_outputs_expected``.

    Args:
        cache_path: Path of the cache file to validate.
        prompts: Prompts forwarded unchanged to the default checker.
        preferred_dtype: Dtype preference forwarded unchanged to the default checker.

    Returns:
        Whatever ``_default_is_disk_cached_outputs_expected`` returns.
    """
    # Work on a copy: `keys += ...` applied directly to the class attribute would
    # extend the shared KEYS list in place, so every masked call would append
    # KEYS_MASKED again and corrupt KEYS for all other users of the class.
    keys = list(FluxTextEncoderOutputsCachingStrategy.KEYS)
    if self.masked:
        keys.extend(FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED)
    return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
    """Return True when ``npz_path`` is a usable text-encoder output cache.

    The file is rejected when disk caching is disabled, when it does not
    exist, when any required array is missing, or when the stored
    ``apply_t5_attn_mask`` flag differs from ``self.apply_t5_attn_mask``.
    When ``self.skip_disk_cache_validity_check`` is set, an existing file is
    accepted without being opened.

    Args:
        npz_path: Path of the npz cache file to validate.

    Returns:
        True if the cache can be used, False otherwise.

    Raises:
        Exception: Any error raised while reading the file is logged and re-raised.
    """
    if not self.cache_to_disk:
        return False
    if not os.path.exists(npz_path):
        return False
    if self.skip_disk_cache_validity_check:
        return True

    required_keys = ("l_pooled", "t5_out", "txt_ids", "t5_attn_mask", "apply_t5_attn_mask")
    try:
        # np.load on an .npz keeps the archive open; the context manager closes it
        # on every exit path so validating many cache files does not leak handles.
        with np.load(npz_path) as npz:
            if any(key not in npz for key in required_keys):
                return False
            # A cache written with a different masking setting is not reusable.
            if npz["apply_t5_attn_mask"] != self.apply_t5_attn_mask:
                return False
    except Exception:
        logger.error(f"Error loading file: {npz_path}")
        raise

    return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
    """Load the cached text-encoder outputs for one caption.

    Args:
        cache_path: Path of the cache file to read.
        caption_index: Index of the caption whose outputs are requested.

    Returns:
        ``[l_pooled, t5_out, txt_ids, t5_attn_mask]``; the last entry is
        ``None`` when ``self.masked`` is false.
    """
    strategy_cls = FluxTextEncoderOutputsCachingStrategy

    base_outputs = self.load_from_disk_for_keys(cache_path, caption_index, strategy_cls.KEYS)
    l_pooled, t5_out, txt_ids = base_outputs

    # The mask is only read from disk for masked caches; only the first of the
    # masked entries is returned to the caller.
    t5_attn_mask = None
    if self.masked:
        masked_outputs = self.load_from_disk_for_keys(cache_path, caption_index, strategy_cls.KEYS_MASKED)
        t5_attn_mask = masked_outputs[0]

    return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
@@ -154,44 +145,38 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = tokens_and_masks[2].cpu()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
for i, info in enumerate(infos):
for i, (info, caption_index, caption) in enumerate(batch):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
if self.masked:
outputs += [t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
@@ -215,8 +200,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype