mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Remove redundant argument apply_t5_attn_mask
This commit is contained in:
@@ -106,11 +106,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
default=512,
|
||||
help="Maximum token length for T5 tokenizer (default: 512)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply_t5_attn_mask",
|
||||
action="store_true",
|
||||
help="Apply attention mask to T5 tokens in LLM adapter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrete_flow_shift",
|
||||
type=float,
|
||||
|
||||
@@ -86,10 +86,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
dropout_rate: float = 0.0,
|
||||
) -> None:
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def encode_tokens(
|
||||
@@ -97,7 +95,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
apply_t5_attn_mask: Optional[bool] = None,
|
||||
enable_dropout: bool = True,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
|
||||
@@ -109,8 +106,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
Returns:
|
||||
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
"""
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
qwen3_text_encoder = models[0]
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
|
||||
@@ -222,10 +217,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
is_partial: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
@@ -279,7 +272,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy,
|
||||
models,
|
||||
tokens_and_masks,
|
||||
apply_t5_attn_mask=self.apply_t5_attn_mask,
|
||||
enable_dropout=False,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user