From d25ae361d06bb6f49c104ca2e6b4a9188a88c95f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 19:07:07 +0900 Subject: [PATCH] fix apply_t5_attn_mask to work --- README.md | 2 ++ flux_train_network.py | 6 ++++-- library/strategy_flux.py | 18 +++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d016bcec..d47776ca 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the text encoder outputs cache file if you have used the option before. + Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. 
diff --git a/flux_train_network.py b/flux_train_network.py index 69b6e8ea..59a666aa 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -67,14 +67,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy() + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + ) else: return None diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 13459d32..3880a1e1 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy): class FluxTextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. 
+ """ + self.apply_t5_attn_mask = apply_t5_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_t5_attn_mask: bool = False, + apply_t5_attn_mask: Optional[bool] = None, ) -> List[torch.Tensor]: - # supports single model inference only + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask clip_l, t5xxl = models l_tokens, t5_tokens = tokens[:2] @@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # attn_mask is not applied when caching to disk: it is applied when loading from disk l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) if l_pooled.dtype == torch.bfloat16: