Remove redundant argument apply_t5_attn_mask
@@ -165,7 +165,6 @@ def train(args):
             args.text_encoder_batch_size,
             False,
             False,
-            args.apply_t5_attn_mask,
         )
     )
     train_dataset_group.set_current_strategies()
@@ -223,7 +222,6 @@ def train(args):
     # Set text encoding strategy
     caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
     text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=args.apply_t5_attn_mask,
         dropout_rate=caption_dropout_rate,
     )
     strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
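The hunks above rely on sd-scripts' strategy registry: train() builds one concrete strategy and registers it on its base class, and other components look it up later. The following is a minimal, generic sketch of that pattern, not the project's actual strategy_base code; the class bodies and the dropout_rate value are illustrative only.

# Generic sketch of the set_strategy/get_strategy registry pattern used above;
# not sd-scripts' implementation.
from typing import Optional


class TextEncodingStrategy:
    _strategy: Optional["TextEncodingStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "TextEncodingStrategy") -> None:
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
        return cls._strategy


class AnimaTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self, dropout_rate: float = 0.0) -> None:
        self.dropout_rate = dropout_rate


# train() registers the concrete strategy once...
TextEncodingStrategy.set_strategy(AnimaTextEncodingStrategy(dropout_rate=0.1))
# ...and other components retrieve it without knowing the concrete type.
current = TextEncodingStrategy.get_strategy()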
@@ -243,7 +241,6 @@ def train(args):
         args.text_encoder_batch_size,
         args.skip_cache_check,
         is_partial=False,
-        apply_t5_attn_mask=args.apply_t5_attn_mask,
     )
     strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -166,7 +166,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
     def get_text_encoding_strategy(self, args):
         caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
         self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
-            apply_t5_attn_mask=args.apply_t5_attn_mask,
             dropout_rate=caption_dropout_rate,
         )
         return self.text_encoding_strategy
@@ -193,7 +192,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
             args.text_encoder_batch_size,
             args.skip_cache_check,
             is_partial=False,
-            apply_t5_attn_mask=args.apply_t5_attn_mask,
         )
         return None
@@ -471,7 +469,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
         return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
 
     def update_metadata(self, metadata, args):
-        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
         metadata["ss_weighting_scheme"] = args.weighting_scheme
         metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
         metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
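For reference, a self-contained sketch of the metadata assembly after this hunk; the function body mirrors the post-change lines above, while the example args values ("none", 3.0) are made up for illustration.

# Illustrative sketch of the post-change metadata update; the args values are
# made up for the example.
from types import SimpleNamespace

def update_metadata(metadata, args):
    # ss_apply_t5_attn_mask is no longer recorded after this commit.
    metadata["ss_weighting_scheme"] = args.weighting_scheme
    metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
    metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')

args = SimpleNamespace(weighting_scheme="none", discrete_flow_shift=3.0)
metadata = {}
update_metadata(metadata, args)
# -> {'ss_weighting_scheme': 'none', 'ss_discrete_flow_shift': 3.0,
#     'ss_timestep_sample_method': 'logit_normal'}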
@@ -169,8 +169,6 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
   - Maximum token length for the Qwen3 tokenizer. Default `512`.
 * `--t5_max_token_length=<integer>`
   - Maximum token length for the T5 tokenizer. Default `512`.
-* `--apply_t5_attn_mask`
-  - Apply attention mask to T5 tokens in the LLM adapter.
 * `--flash_attn`
   - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
 * `--transformer_dtype=<choice>`
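The `--flash_attn` note above mentions falling back to PyTorch SDPA when the package is missing. A generic sketch of such an availability check follows; it is illustrative only, not the project's actual dispatch code, and the attention helper and tensor layout are assumptions.

# Generic sketch of a flash-attn availability check with SDPA fallback;
# illustrative, not the project's actual attention dispatch.
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # requires `pip install flash-attn`
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def attention(q, k, v, use_flash_attn: bool):
    if use_flash_attn and HAS_FLASH_ATTN:
        # flash_attn_func expects (batch, seqlen, heads, head_dim) tensors
        return flash_attn_func(q, k, v)
    # Fall back to PyTorch's scaled_dot_product_attention (SDPA), which
    # expects (batch, heads, seqlen, head_dim), so transpose around it.
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)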
@@ -229,7 +227,6 @@ Anima supports 6 independent learning rate groups. Set to `0` to freeze a compon
 * `--sigmoid_scale` - Scale factor for logit_normal timestep sampling. Default `1.0`.
 * `--qwen3_max_token_length` - Maximum token length for the Qwen3 tokenizer. Default `512`.
 * `--t5_max_token_length` - Maximum token length for the T5 tokenizer. Default `512`.
-* `--apply_t5_attn_mask` - Apply an attention mask to the T5 tokens in the LLM Adapter.
 * `--flash_attn` - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`.
 * `--transformer_dtype` - Separate dtype for the Transformer blocks.
@@ -537,7 +534,6 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
 
 The following Anima-specific metadata is saved in the LoRA model file:
 
-* `ss_apply_t5_attn_mask`
 * `ss_weighting_scheme`
 * `ss_discrete_flow_shift`
 * `ss_timestep_sample_method`
@@ -552,7 +548,6 @@ The following Anima-specific metadata is saved in the LoRA model file:
 
 The following Anima-specific metadata is saved in the LoRA model file:
 
-* `ss_apply_t5_attn_mask`
 * `ss_weighting_scheme`
 * `ss_discrete_flow_shift`
 * `ss_timestep_sample_method`
@@ -106,11 +106,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         default=512,
         help="Maximum token length for T5 tokenizer (default: 512)",
     )
-    parser.add_argument(
-        "--apply_t5_attn_mask",
-        action="store_true",
-        help="Apply attention mask to T5 tokens in LLM adapter",
-    )
     parser.add_argument(
         "--discrete_flow_shift",
         type=float,
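With the registration above deleted, `--apply_t5_attn_mask` is no longer accepted by the parser. A self-contained sketch of the resulting behavior; only two of the remaining arguments are reproduced, and the `type=int` on `--t5_max_token_length` plus the example value are assumptions.

# Simplified, self-contained sketch of the parser after this commit; only a
# subset of the real arguments is reproduced for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--t5_max_token_length",
    type=int,
    default=512,
    help="Maximum token length for T5 tokenizer (default: 512)",
)
parser.add_argument(
    "--discrete_flow_shift",
    type=float,
)

args = parser.parse_args(["--discrete_flow_shift", "3.0"])
print(args.t5_max_token_length)  # 512
# parser.parse_args(["--apply_t5_attn_mask"]) now exits with
# "error: unrecognized arguments: --apply_t5_attn_mask".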
@@ -86,10 +86,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
 
     def __init__(
         self,
-        apply_t5_attn_mask: bool = False,
         dropout_rate: float = 0.0,
     ) -> None:
-        self.apply_t5_attn_mask = apply_t5_attn_mask
         self.dropout_rate = dropout_rate
 
     def encode_tokens(
@@ -97,7 +95,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         tokenize_strategy: TokenizeStrategy,
         models: List[Any],
         tokens: List[torch.Tensor],
-        apply_t5_attn_mask: Optional[bool] = None,
         enable_dropout: bool = True,
     ) -> List[torch.Tensor]:
         """Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -109,8 +106,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
         Returns:
             [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
         """
-        if apply_t5_attn_mask is None:
-            apply_t5_attn_mask = self.apply_t5_attn_mask
 
         qwen3_text_encoder = models[0]
         qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
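Based on the simplified signature and docstring above, a call to `encode_tokens` would look roughly like the sketch below. It assumes a loaded Qwen3 text encoder, a tokenize strategy, and tokenized captions already exist; `qwen3_text_encoder`, `tokenize_strategy`, and `tokens` are placeholder names, not objects defined here.

# Sketch only: qwen3_text_encoder, tokenize_strategy and tokens are assumed to
# already exist; they are placeholders for real objects.
strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)

# tokens is the 4-item result of tokenization:
# [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
outputs = strategy.encode_tokens(
    tokenize_strategy,
    [qwen3_text_encoder],
    tokens,
    enable_dropout=False,  # e.g. when caching, as in the caching strategy below
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = outputs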
@@ -222,10 +217,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         batch_size: int,
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
-        apply_t5_attn_mask: bool = False,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
-        self.apply_t5_attn_mask = apply_t5_attn_mask
 
     def get_outputs_npz_path(self, image_abs_path: str) -> str:
         return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
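After the removal, the caching strategy only carries the base-class state. An illustrative construction mirroring the call sites earlier in this diff; the concrete values (True, 1, False) are made up for the example.

# Illustrative construction of the simplified caching strategy; values made up.
caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
    True,    # cache_to_disk
    1,       # batch_size (args.text_encoder_batch_size at the call sites)
    False,   # skip_disk_cache_validity_check (args.skip_cache_check)
    is_partial=False,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(caching_strategy)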
@@ -279,7 +272,6 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             tokenize_strategy,
             models,
             tokens_and_masks,
-            apply_t5_attn_mask=self.apply_t5_attn_mask,
             enable_dropout=False,
         )
@@ -200,7 +200,6 @@ def test_text_encoder_cache(args, pairs):
         t5_max_length=args.t5_max_length,
     )
     text_encoding_strategy = AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=False,
        dropout_rate=0.0,
     )
@@ -355,7 +354,6 @@ def test_text_encoder_cache(args, pairs):
     # Test drop_cached_text_encoder_outputs
     print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
     dropout_strategy = AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=False,
         dropout_rate=0.5,  # high rate to ensure some drops
     )
     dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
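The `dropout_rate=0.5` test above exercises caption dropout over cached text-encoder outputs. As a generic illustration of the idea only (not the project's drop_cached_text_encoder_outputs, and zeroing the embedding is an assumption made for the example), per-sample dropout with probability p can look like:

# Generic illustration of probability-based caption dropout over cached
# embeddings; not the project's actual implementation.
import torch

def drop_cached_outputs(prompt_embeds: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    # prompt_embeds: (batch, seq_len, dim). With probability dropout_rate, a
    # sample's embedding is replaced by zeros (an "empty caption" stand-in).
    keep = (torch.rand(prompt_embeds.shape[0]) >= dropout_rate).float()
    return prompt_embeds * keep.view(-1, 1, 1)

embeds = torch.randn(4, 77, 1024)
dropped = drop_cached_outputs(embeds, dropout_rate=0.5)  # roughly half the samples zeroed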
@@ -401,7 +399,7 @@ def test_full_batch_simulation(args, pairs):
         qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
         qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
     )
-    text_encoding_strategy = AnimaTextEncodingStrategy(apply_t5_attn_mask=False, dropout_rate=0.0)
+    text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
 
     captions = [cap for _, cap in pairs]