diff --git a/anima_train.py b/anima_train.py
index 9661cd90..d703d02e 100644
--- a/anima_train.py
+++ b/anima_train.py
@@ -165,7 +165,6 @@ def train(args):
                 args.text_encoder_batch_size,
                 False,
                 False,
-                args.apply_t5_attn_mask,
             )
         )
     train_dataset_group.set_current_strategies()
@@ -223,7 +222,6 @@ def train(args):
     # Set text encoding strategy
     caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
     text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=args.apply_t5_attn_mask,
         dropout_rate=caption_dropout_rate,
     )
     strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
@@ -243,7 +241,6 @@ def train(args):
         args.text_encoder_batch_size,
         args.skip_cache_check,
         is_partial=False,
-        apply_t5_attn_mask=args.apply_t5_attn_mask,
     )
     strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

diff --git a/anima_train_network.py b/anima_train_network.py
index 5232470f..2416cd7b 100644
--- a/anima_train_network.py
+++ b/anima_train_network.py
@@ -166,7 +166,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
     def get_text_encoding_strategy(self, args):
         caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
         self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
-            apply_t5_attn_mask=args.apply_t5_attn_mask,
             dropout_rate=caption_dropout_rate,
         )
         return self.text_encoding_strategy
@@ -193,7 +192,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
                 args.text_encoder_batch_size,
                 args.skip_cache_check,
                 is_partial=False,
-                apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         return None

@@ -471,7 +469,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
         return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)

     def update_metadata(self, metadata, args):
-        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
         metadata["ss_weighting_scheme"] = args.weighting_scheme
         metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
         metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
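Note for reviewers: after this change, `AnimaTextEncodingStrategy` is constructed from `dropout_rate` alone. Below is a minimal sketch of the post-change call site, assuming the `library` package layout these scripts import from; the dropout value is illustrative, not taken from the diff.

```python
# Minimal sketch of the call site after this change, mirroring anima_train.py.
# The dropout value is illustrative; the trainer reads it from args.
from library import strategy_anima, strategy_base

caption_dropout_rate = 0.0  # illustrative; anima_train.py uses getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
    dropout_rate=caption_dropout_rate,  # apply_t5_attn_mask is no longer accepted
)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
```

A stale call site that still passes the removed keyword will raise a `TypeError` (unexpected keyword argument) rather than silently ignoring the flag.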
diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md
index 0e29a7a9..fe6b2354 100644
--- a/docs/anima_train_network.md
+++ b/docs/anima_train_network.md
@@ -169,8 +169,6 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
   - Maximum token length for the Qwen3 tokenizer. Default `512`.
 * `--t5_max_token_length=`
   - Maximum token length for the T5 tokenizer. Default `512`.
-* `--apply_t5_attn_mask`
-  - Apply attention mask to T5 tokens in the LLM adapter.
 * `--flash_attn`
   - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
 * `--transformer_dtype=`
@@ -229,7 +227,6 @@ Anima supports 6 independent learning rate groups.
 Set to `0` to freeze a component.
 * `--sigmoid_scale` - Scale factor for logit_normal timestep sampling. Default `1.0`.
 * `--qwen3_max_token_length` - Maximum token length for the Qwen3 tokenizer. Default `512`.
 * `--t5_max_token_length` - Maximum token length for the T5 tokenizer. Default `512`.
-* `--apply_t5_attn_mask` - Apply an attention mask to T5 tokens in the LLM Adapter.
 * `--flash_attn` - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`.
 * `--transformer_dtype` - Separate dtype for the Transformer blocks.
@@ -537,7 +534,6 @@ In Anima LoRA training, a LoRA for the Qwen3 text encoder is also trained

 The following Anima-specific metadata is saved in the LoRA model file:

-* `ss_apply_t5_attn_mask`
 * `ss_weighting_scheme`
 * `ss_discrete_flow_shift`
 * `ss_timestep_sample_method`
@@ -552,7 +548,6 @@ The following Anima-specific metadata is saved in the LoRA model file:

 The following Anima-specific metadata is saved in the LoRA model file:

-* `ss_apply_t5_attn_mask`
 * `ss_weighting_scheme`
 * `ss_discrete_flow_shift`
 * `ss_timestep_sample_method`

diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py
index 85d433f0..d181eea1 100644
--- a/library/anima_train_utils.py
+++ b/library/anima_train_utils.py
@@ -106,11 +106,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         default=512,
         help="Maximum token length for T5 tokenizer (default: 512)",
     )
-    parser.add_argument(
-        "--apply_t5_attn_mask",
-        action="store_true",
-        help="Apply attention mask to T5 tokens in LLM adapter",
-    )
     parser.add_argument(
         "--discrete_flow_shift",
         type=float,
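One behavioral consequence worth flagging: because `add_anima_training_arguments` no longer registers `--apply_t5_attn_mask`, launch scripts that still pass it will be rejected by argparse. A standalone illustration using plain `argparse`, not the project's parser:

```python
# Standalone illustration: a parser that no longer defines the flag reports it
# as unrecognized, so stale command lines fail fast instead of silently no-op'ing.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--t5_max_token_length", type=int, default=512)
# "--apply_t5_attn_mask" is intentionally no longer registered.

args, unknown = parser.parse_known_args(["--apply_t5_attn_mask"])
assert unknown == ["--apply_t5_attn_mask"]  # parse_args() would exit with an error
```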
diff --git a/library/strategy_anima.py b/library/strategy_anima.py
index a545f9e9..a2375e88 100644
--- a/library/strategy_anima.py
+++ b/library/strategy_anima.py
@@ -86,10 +86,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
     def __init__(
         self,
-        apply_t5_attn_mask: bool = False,
         dropout_rate: float = 0.0,
     ) -> None:
-        self.apply_t5_attn_mask = apply_t5_attn_mask
         self.dropout_rate = dropout_rate

     def encode_tokens(
         self,
@@ -97,7 +95,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         tokenize_strategy: TokenizeStrategy,
         models: List[Any],
         tokens: List[torch.Tensor],
-        apply_t5_attn_mask: Optional[bool] = None,
         enable_dropout: bool = True,
     ) -> List[torch.Tensor]:
         """Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -109,8 +106,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         Returns:
             [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
         """
-        if apply_t5_attn_mask is None:
-            apply_t5_attn_mask = self.apply_t5_attn_mask
         qwen3_text_encoder = models[0]
         qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens

@@ -222,10 +217,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         batch_size: int,
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
-        apply_t5_attn_mask: bool = False,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
-        self.apply_t5_attn_mask = apply_t5_attn_mask

     def get_outputs_npz_path(self, image_abs_path: str) -> str:
         return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -279,7 +272,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             tokenize_strategy,
             models,
             tokens_and_masks,
-            apply_t5_attn_mask=self.apply_t5_attn_mask,
             enable_dropout=False,
         )

diff --git a/tests/test_anima_cache.py b/tests/test_anima_cache.py
index 8f9872e4..1684eb53 100644
--- a/tests/test_anima_cache.py
+++ b/tests/test_anima_cache.py
@@ -200,7 +200,6 @@ def test_text_encoder_cache(args, pairs):
         t5_max_length=args.t5_max_length,
     )
     text_encoding_strategy = AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=False,
         dropout_rate=0.0,
     )

@@ -355,7 +354,6 @@ def test_text_encoder_cache(args, pairs):
     # Test drop_cached_text_encoder_outputs
     print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
     dropout_strategy = AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=False,
         dropout_rate=0.5,  # high rate to ensure some drops
     )
     dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
@@ -401,7 +399,7 @@ def test_full_batch_simulation(args, pairs):
         qwen3_tokenizer=qwen3_tokenizer,
         t5_tokenizer=t5_tokenizer,
         qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
     )
-    text_encoding_strategy = AnimaTextEncodingStrategy(apply_t5_attn_mask=False, dropout_rate=0.0)
+    text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)

     captions = [cap for _, cap in pairs]
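For downstream callers, the surviving constructor signatures are summarized below. This is a sketch based only on the signatures visible in this diff; the argument values are placeholders, and tokenizer/model setup is elided.

```python
# Sketch of the post-change constructors; the names come from
# library/strategy_anima.py in this diff, the values are placeholders.
from library.strategy_anima import (
    AnimaTextEncodingStrategy,
    AnimaTextEncoderOutputsCachingStrategy,
)

# Encoding strategy: dropout_rate is now the only constructor argument.
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)

# Caching strategy: the apply_t5_attn_mask keyword is gone; the remaining
# parameters are unchanged by this PR.
caching_strategy = AnimaTextEncoderOutputsCachingStrategy(
    cache_to_disk=True,
    batch_size=1,
    skip_disk_cache_validity_check=False,
    is_partial=False,
)
```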