mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix apply_t5_attn_mask to work
This commit is contained in:
@@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif
|
||||
|
||||
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
|
||||
|
||||
Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before.
|
||||
|
||||
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
|
||||
|
||||
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.
|
||||
|
||||
@@ -67,14 +67,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
return latents_caching_strategy
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_flux.FluxTextEncodingStrategy()
|
||||
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy):
|
||||
|
||||
|
||||
class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
||||
"""
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
apply_t5_attn_mask: bool = False,
|
||||
apply_t5_attn_mask: Optional[bool] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
# supports single model inference only
|
||||
# supports single model inference
|
||||
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
clip_l, t5xxl = models
|
||||
l_tokens, t5_tokens = tokens[:2]
|
||||
@@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
||||
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask
|
||||
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
||||
)
|
||||
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user