mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix apply_t5_attn_mask to work
This commit is contained in:
@@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||||||
|
|
||||||
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any ideas to improve the training.
|
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any ideas to improve the training.
|
||||||
|
|
||||||
|
Aug 11, 2024: Fixed the `--apply_t5_attn_mask` option so that it works. Please remove and re-generate the latents cache file if you have used the option before.
|
||||||
|
|
||||||
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
|
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
|
||||||
|
|
||||||
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.
|
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.
|
||||||
|
|||||||
@@ -67,14 +67,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return latents_caching_strategy
|
return latents_caching_strategy
|
||||||
|
|
||||||
def get_text_encoding_strategy(self, args):
|
def get_text_encoding_strategy(self, args):
|
||||||
return strategy_flux.FluxTextEncodingStrategy()
|
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
||||||
|
|
||||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||||
return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])]
|
return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])]
|
||||||
|
|
||||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
|
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||||
|
args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy):
|
|||||||
|
|
||||||
|
|
||||||
class FluxTextEncodingStrategy(TextEncodingStrategy):
|
class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||||
def __init__(self) -> None:
|
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
||||||
pass
|
"""
|
||||||
|
Args:
|
||||||
|
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
||||||
|
"""
|
||||||
|
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||||
|
|
||||||
def encode_tokens(
|
def encode_tokens(
|
||||||
self,
|
self,
|
||||||
tokenize_strategy: TokenizeStrategy,
|
tokenize_strategy: TokenizeStrategy,
|
||||||
models: List[Any],
|
models: List[Any],
|
||||||
tokens: List[torch.Tensor],
|
tokens: List[torch.Tensor],
|
||||||
apply_t5_attn_mask: bool = False,
|
apply_t5_attn_mask: Optional[bool] = None,
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
# supports single model inference only
|
# supports single model inference
|
||||||
|
|
||||||
|
if apply_t5_attn_mask is None:
|
||||||
|
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||||
|
|
||||||
clip_l, t5xxl = models
|
clip_l, t5xxl = models
|
||||||
l_tokens, t5_tokens = tokens[:2]
|
l_tokens, t5_tokens = tokens[:2]
|
||||||
@@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
|
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
||||||
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
|
l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask
|
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
||||||
)
|
)
|
||||||
|
|
||||||
if l_pooled.dtype == torch.bfloat16:
|
if l_pooled.dtype == torch.bfloat16:
|
||||||
|
|||||||
Reference in New Issue
Block a user