fix apply_t5_attn_mask to work

This commit is contained in:
Kohya S
2024-08-11 19:07:07 +09:00
parent 82314ac2e7
commit d25ae361d0
3 changed files with 19 additions and 7 deletions

View File

@@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy):
class FluxTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_t5_attn_mask: bool = False,
apply_t5_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
# supports single model inference only
# supports single model inference
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models
l_tokens, t5_tokens = tokens[:2]
@@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is not applied when caching to disk: it is applied when loading from disk
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
)
if l_pooled.dtype == torch.bfloat16: