Fix to remove zero pad for t5xxl output

Author: Kohya S
Date: 2024-08-22 19:56:27 +09:00
Parent: a4d27a232b
Commit: 2d8fa3387a
2 changed files with 16 additions and 12 deletions
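
Background for the fix, as a hedged NumPy illustration: before this commit, the removed mask_t5_attn helper multiplied the cached T5-XXL output by the expanded attention mask at load time, writing zeros into every padded token position (the "zero pad" the title refers to). The array shapes below are invented for the example.

import numpy as np

# Illustrative shapes only: 2 captions, 8 tokens, 4-dim embeddings.
t5_out = np.random.randn(2, 8, 4).astype(np.float32)  # stand-in for T5-XXL hidden states
t5_attn_mask = np.array([[1, 1, 1, 0, 0, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.float32)

# Old behavior (removed by this commit): multiplying by the expanded mask
# zeroes the embedding of every padded token in the cached output.
masked = t5_out * np.expand_dims(t5_attn_mask, -1)
assert np.all(masked[0, 3:] == 0)  # padded positions of the first caption are zeroed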


@@ -22,7 +22,7 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
 class FluxTokenizeStrategy(TokenizeStrategy):
-    def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
+    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
         self.t5xxl_max_length = t5xxl_max_length
         self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
         self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
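
For context on the new 512-token default: the sketch below shows what padding a caption to t5xxl_max_length looks like with the same google/t5-v1_1-xxl tokenizer ID the strategy loads. The caption text and the printed values are made up for illustration; the strategy's own tokenize() call may differ in detail.

from transformers import T5TokenizerFast

# Padding to max_length is what creates the padded positions (and the matching
# attention mask) that this commit's zero-pad handling is about.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
tokens = tokenizer(
    "a photo of a cat",       # example caption, not from the repo
    max_length=512,           # new default t5xxl_max_length
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(tokens["input_ids"].shape)             # torch.Size([1, 512])
print(int(tokens["attention_mask"].sum()))   # number of real (non-pad) tokens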
@@ -120,25 +120,24 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                 return False
             if "t5_attn_mask" not in npz:
                 return False
+            if "apply_t5_attn_mask" not in npz:
+                return False
+            npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
+            if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
+                return False
         except Exception as e:
             logger.error(f"Error loading file: {npz_path}")
             raise e
 
         return True
 
-    def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
-        return t5_out * np.expand_dims(t5_attn_mask, -1)
 
     def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
         data = np.load(npz_path)
         l_pooled = data["l_pooled"]
         t5_out = data["t5_out"]
         txt_ids = data["txt_ids"]
         t5_attn_mask = data["t5_attn_mask"]
-        if self.apply_t5_attn_mask:
-            t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)  # FIXME do not mask here!!!
+        # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
         return [l_pooled, t5_out, txt_ids, t5_attn_mask]
 
     def cache_batch_outputs(
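
The standalone helper below is only an illustration of the validity check added above (it is not part of the repository). It shows why a cache written with one apply_t5_attn_mask setting is rejected when the current run requests the other setting.

import numpy as np

def cache_matches(npz_path: str, apply_t5_attn_mask: bool) -> bool:
    # Hypothetical mirror of is_disk_cached_outputs_expected's key checks:
    # the cache is only reusable when it stores the apply_t5_attn_mask flag
    # and that flag matches the currently requested setting.
    npz = np.load(npz_path)
    for key in ("l_pooled", "t5_out", "txt_ids", "t5_attn_mask", "apply_t5_attn_mask"):
        if key not in npz:
            return False
    return bool(npz["apply_t5_attn_mask"]) == apply_t5_attn_mask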
@@ -149,10 +148,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         tokens_and_masks = tokenize_strategy.tokenize(captions)
         with torch.no_grad():
-            # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading
-            l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
-                tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
-            )
+            # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
+            l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
 
         if l_pooled.dtype == torch.bfloat16:
             l_pooled = l_pooled.float()
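
How the mask is "applied in encode_tokens" is not visible in this diff. The sketch below is one plausible reading, using the Hugging Face T5EncoderModel rather than the repository's own T5-XXL wrapper: the attention mask is passed to the encoder's forward call at encoding time, instead of being multiplied into the cached output afterwards. Model ID, token ids, and mask contents are placeholders.

import torch
from transformers import T5EncoderModel

# Hypothetical sketch, not the repo's encode_tokens. Passing attention_mask to
# the encoder keeps padded tokens from influencing the real tokens, without
# zeroing the cached output later.
t5xxl = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl", torch_dtype=torch.bfloat16)
input_ids = torch.randint(0, 32000, (1, 512))        # made-up token ids
attn_mask = torch.zeros(1, 512, dtype=torch.long)
attn_mask[:, :10] = 1                                 # pretend 10 real tokens, rest padding
with torch.no_grad():
    t5_out = t5xxl(input_ids=input_ids, attention_mask=attn_mask).last_hidden_state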
@@ -171,6 +168,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             t5_out_i = t5_out[i]
             txt_ids_i = txt_ids[i]
             t5_attn_mask_i = t5_attn_mask[i]
+            apply_t5_attn_mask_i = self.apply_t5_attn_mask
 
             if self.cache_to_disk:
                 np.savez(
@@ -179,6 +177,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     t5_out=t5_out_i,
                     txt_ids=txt_ids_i,
                     t5_attn_mask=t5_attn_mask_i,
+                    apply_t5_attn_mask=apply_t5_attn_mask_i,
                 )
             else:
                 info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
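
Finally, a hedged roundtrip sketch of the extended npz cache format: the apply_t5_attn_mask flag now travels with the cached tensors, so a cache written under one setting is not silently reused under the other. Only the key names come from the diff; the file name and array shapes are invented for the example.

import numpy as np

np.savez(
    "caption_te_outputs.npz",                        # hypothetical path
    l_pooled=np.zeros((768,), dtype=np.float32),     # made-up shapes
    t5_out=np.zeros((512, 4096), dtype=np.float32),
    txt_ids=np.zeros((512, 3), dtype=np.float32),
    t5_attn_mask=np.ones((512,), dtype=np.int64),
    apply_t5_attn_mask=True,                          # new key added by this commit
)

npz = np.load("caption_te_outputs.npz")
print(bool(npz["apply_t5_attn_mask"]))  # True; compared against the strategy's own setting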