Fix the `--apply_t5_attn_mask` option so that it works as intended

This commit is contained in:
Kohya S
2024-08-11 19:07:07 +09:00
parent 82314ac2e7
commit d25ae361d0
3 changed files with 19 additions and 7 deletions

View File

@@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before.
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.

View File

@@ -67,14 +67,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return latents_caching_strategy return latents_caching_strategy
def get_text_encoding_strategy(self, args): def get_text_encoding_strategy(self, args):
return strategy_flux.FluxTextEncodingStrategy() return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
def get_models_for_text_encoding(self, args, accelerator, text_encoders): def get_models_for_text_encoding(self, args, accelerator, text_encoders):
return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])]
def get_text_encoder_outputs_caching_strategy(self, args): def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask
)
else: else:
return None return None

View File

@@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy):
class FluxTextEncodingStrategy(TextEncodingStrategy): class FluxTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None: def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
pass """
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens( def encode_tokens(
self, self,
tokenize_strategy: TokenizeStrategy, tokenize_strategy: TokenizeStrategy,
models: List[Any], models: List[Any],
tokens: List[torch.Tensor], tokens: List[torch.Tensor],
apply_t5_attn_mask: bool = False, apply_t5_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
# supports single model inference only # supports single model inference
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models clip_l, t5xxl = models
l_tokens, t5_tokens = tokens[:2] l_tokens, t5_tokens = tokens[:2]
@@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions) tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad(): with torch.no_grad():
# attn_mask is not applied when caching to disk: it is applied when loading from disk
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
) )
if l_pooled.dtype == torch.bfloat16: if l_pooled.dtype == torch.bfloat16: