Remove redundant argument apply_t5_attn_mask

This commit is contained in:
Duoong
2026-02-07 18:00:20 +07:00
parent f41a9f02e4
commit 96a3ae2f87
6 changed files with 1 addition and 27 deletions

View File

@@ -106,11 +106,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
default=512,
help="Maximum token length for T5 tokenizer (default: 512)",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="Apply attention mask to T5 tokens in LLM adapter",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,

View File

@@ -86,10 +86,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
def __init__(
self,
apply_t5_attn_mask: bool = False,
dropout_rate: float = 0.0,
) -> None:
self.apply_t5_attn_mask = apply_t5_attn_mask
self.dropout_rate = dropout_rate
def encode_tokens(
@@ -97,7 +95,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_t5_attn_mask: Optional[bool] = None,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -109,8 +106,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
@@ -222,10 +217,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -279,7 +272,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy,
models,
tokens_and_masks,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)