Remove redundant argument apply_t5_attn_mask

Duoong
2026-02-07 18:00:20 +07:00
parent f41a9f02e4
commit 96a3ae2f87
6 changed files with 1 addition and 27 deletions
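
After this commit, the T5 attention-mask flag disappears from both strategy constructors. A minimal sketch of the updated call sites; the import path is an assumption, `cache_to_disk` is inferred from the `super().__init__` call in the diff below, and the values are placeholders:

from library import strategy_anima  # import path is an assumption

# Text encoding: caption dropout is the only remaining option here.
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
    dropout_rate=0.0,
)

# Output caching: apply_t5_attn_mask is gone from this constructor as well.
caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
    cache_to_disk=True,                    # placeholder value
    batch_size=1,                          # placeholder value
    skip_disk_cache_validity_check=False,  # placeholder value
    is_partial=False,
)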

View File

@@ -165,7 +165,6 @@ def train(args):
args.text_encoder_batch_size,
False,
False,
- args.apply_t5_attn_mask,
)
)
train_dataset_group.set_current_strategies()
@@ -223,7 +222,6 @@ def train(args):
# Set text encoding strategy
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
- apply_t5_attn_mask=args.apply_t5_attn_mask,
dropout_rate=caption_dropout_rate,
)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
@@ -243,7 +241,6 @@ def train(args):
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
- apply_t5_attn_mask=args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

View File

@@ -166,7 +166,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoding_strategy(self, args):
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
- apply_t5_attn_mask=args.apply_t5_attn_mask,
dropout_rate=caption_dropout_rate,
)
return self.text_encoding_strategy
@@ -193,7 +192,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
- apply_t5_attn_mask=args.apply_t5_attn_mask,
)
return None
@@ -471,7 +469,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
def update_metadata(self, metadata, args):
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')

View File

@@ -169,8 +169,6 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
- Maximum token length for the Qwen3 tokenizer. Default `512`.
* `--t5_max_token_length=<integer>`
- Maximum token length for the T5 tokenizer. Default `512`.
- * `--apply_t5_attn_mask`
- - Apply attention mask to T5 tokens in the LLM adapter.
* `--flash_attn`
- Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
* `--transformer_dtype=<choice>`
@@ -229,7 +227,6 @@ Anima supports 6 independent learning rate groups. Set to `0` to freeze a compon
* `--sigmoid_scale` - Scale factor for logit_normal timestep sampling. Default `1.0`.
* `--qwen3_max_token_length` - Maximum token length for the Qwen3 tokenizer. Default `512`.
* `--t5_max_token_length` - Maximum token length for the T5 tokenizer. Default `512`.
- * `--apply_t5_attn_mask` - Apply attention mask to T5 tokens in the LLM Adapter.
* `--flash_attn` - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`.
* `--transformer_dtype` - Separate dtype for the Transformer blocks.
@@ -537,7 +534,6 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
The following Anima-specific metadata is saved in the LoRA model file:
- * `ss_apply_t5_attn_mask`
* `ss_weighting_scheme`
* `ss_discrete_flow_shift`
* `ss_timestep_sample_method`
@@ -552,7 +548,6 @@ The following Anima-specific metadata is saved in the LoRA model file:
The following Anima-specific metadata is saved in the LoRA model file:
- * `ss_apply_t5_attn_mask`
* `ss_weighting_scheme`
* `ss_discrete_flow_shift`
* `ss_timestep_sample_method`
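
These keys can be inspected directly on a saved LoRA file. A minimal sketch using safetensors; the file path is a placeholder:

from safetensors import safe_open

# Sketch: read the Anima-specific metadata keys listed above.
# "anima_lora.safetensors" is a placeholder path.
with safe_open("anima_lora.safetensors", framework="pt") as f:
    meta = f.metadata() or {}

print(meta.get("ss_weighting_scheme"))
print(meta.get("ss_discrete_flow_shift"))
print(meta.get("ss_timestep_sample_method"))
# ss_apply_t5_attn_mask is no longer written after this commit.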

View File

@@ -106,11 +106,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
default=512,
help="Maximum token length for T5 tokenizer (default: 512)",
)
- parser.add_argument(
- "--apply_t5_attn_mask",
- action="store_true",
- help="Apply attention mask to T5 tokens in LLM adapter",
- )
parser.add_argument(
"--discrete_flow_shift",
type=float,

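Since the flag is no longer registered by add_anima_training_arguments, parsers built from this helper simply do not accept it. A minimal sketch of checking that; the module name is a placeholder, and the remaining flag names follow the documentation above:

import argparse

from anima_train_util import add_anima_training_arguments  # module name is a placeholder

parser = argparse.ArgumentParser()
add_anima_training_arguments(parser)

# Documented flags such as --t5_max_token_length are still accepted; the removed flag is not.
args = parser.parse_args(["--t5_max_token_length", "512"])
assert not hasattr(args, "apply_t5_attn_mask")  # removed by this commit
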
View File

@@ -86,10 +86,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
def __init__(
self,
- apply_t5_attn_mask: bool = False,
dropout_rate: float = 0.0,
) -> None:
- self.apply_t5_attn_mask = apply_t5_attn_mask
self.dropout_rate = dropout_rate
def encode_tokens(
@@ -97,7 +95,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
- apply_t5_attn_mask: Optional[bool] = None,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -109,8 +106,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
- if apply_t5_attn_mask is None:
- apply_t5_attn_mask = self.apply_t5_attn_mask
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
@@ -222,10 +217,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
- apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
- self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -279,7 +272,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy,
models,
tokens_and_masks,
- apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)
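
With the per-call override gone, encode_tokens keeps only the signature shown above. A minimal sketch of the cached-output path; tokenize_strategy, models, and tokens_and_masks stand in for objects built elsewhere during training:

# Sketch: encode Qwen3 tokens without caption dropout, as the caching
# strategy does after this commit; all inputs here are placeholders.
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
    tokenize_strategy,
    models,            # [qwen3_text_encoder]
    tokens_and_masks,  # [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
    enable_dropout=False,
)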

View File

@@ -200,7 +200,6 @@ def test_text_encoder_cache(args, pairs):
t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy(
- apply_t5_attn_mask=False,
dropout_rate=0.0,
)
@@ -355,7 +354,6 @@ def test_text_encoder_cache(args, pairs):
# Test drop_cached_text_encoder_outputs
print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
dropout_strategy = AnimaTextEncodingStrategy(
- apply_t5_attn_mask=False,
dropout_rate=0.5, # high rate to ensure some drops
)
dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
@@ -401,7 +399,7 @@ def test_full_batch_simulation(args, pairs):
qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
)
- text_encoding_strategy = AnimaTextEncodingStrategy(apply_t5_attn_mask=False, dropout_rate=0.0)
+ text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
captions = [cap for _, cap in pairs]