feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

This commit is contained in:
kohya-ss
2026-02-09 12:43:11 +09:00
parent a1e3d02259
commit 2774e7757b

View File

@@ -1191,6 +1191,10 @@ class MiniTrainDIT(nn.Module):
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
# def set_flash_attn(self, use_flash_attn: bool):
# """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
@@ -1517,6 +1521,7 @@ class LLMAdapterTransformerBlock(nn.Module):
position_embeddings_context=None,
):
if self.has_self_attn:
# Self-attention: target_attention_mask is not expected to be all zeros
normed = self.norm_self_attn(x)
attn_out = self.self_attn(
normed,
@@ -1526,15 +1531,33 @@ class LLMAdapterTransformerBlock(nn.Module):
)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
if source_attention_mask is not None:
# Select batch elements where source_attention_mask has at least one True value
batch_indices = torch.where(source_attention_mask.any(dim=(1, 2, 3)))[0]
# print("Batch indices for cross-attention:", batch_indices)
if len(batch_indices) == 0:
pass # No valid batch elements, skip cross-attention
else:
normed = self.norm_cross_attn(x[batch_indices])
attn_out = self.cross_attn(
normed,
mask=None,
context=context[batch_indices],
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x[batch_indices] = x[batch_indices] + attn_out
else:
# Standard cross-attention without masking
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
@@ -1619,27 +1642,53 @@ class Anima(nn.Module):
def dtype(self):
return self.net.dtype
def enable_gradient_checkpointing(self, *args, **kwargs):
self.net.enable_gradient_checkpointing(*args, **kwargs)
def disable_gradient_checkpointing(self):
self.net.disable_gradient_checkpointing()
def enable_block_swap(self, *args, **kwargs):
self.net.enable_block_swap(*args, **kwargs)
def move_to_device_except_swap_blocks(self, *args, **kwargs):
self.net.move_to_device_except_swap_blocks(*args, **kwargs)
def prepare_block_swap_before_forward(self, *args, **kwargs):
self.net.prepare_block_swap_before_forward(*args, **kwargs)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
context: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
target_input_ids: Optional[torch.Tensor] = None,
target_attention_mask: Optional[torch.Tensor] = None,
source_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
def preprocess_text_embeds(
def _preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
return self.net.llm_adapter(
# print(
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
# )
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
context = self.net.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
return context
else:
return source_hidden_states