mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix apply_t5_attn_mask to work
This commit is contained in:
@@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy):
|
||||
|
||||
|
||||
class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
||||
"""
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
apply_t5_attn_mask: bool = False,
|
||||
apply_t5_attn_mask: Optional[bool] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
# supports single model inference only
|
||||
# supports single model inference
|
||||
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
clip_l, t5xxl = models
|
||||
l_tokens, t5_tokens = tokens[:2]
|
||||
@@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
||||
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask
|
||||
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
||||
)
|
||||
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user