From 2774e7757bb037907b2a3d7fa5db8199a9c30c46 Mon Sep 17 00:00:00 2001
From: kohya-ss <52813779+kohya-ss@users.noreply.github.com>
Date: Mon, 9 Feb 2026 12:43:11 +0900
Subject: [PATCH] feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

---
Review notes (text in this section is ignored by git am):

* MiniTrainDIT gains a dtype property mirroring the existing device property.
* LLMAdapterTransformerBlock.forward: when source_attention_mask is given,
  cross-attention runs only for batch elements whose mask has at least one
  True value; the remaining elements skip the cross-attention residual.
  NOTE(review): the selected elements call cross_attn with mask=None, so the
  per-token mask is not applied to them, and x is updated with an indexed
  in-place assignment. Confirm both are intended.
* Anima gains pass-through wrappers for the gradient checkpointing and block
  swap calls on self.net, and Anima.forward now runs _preprocess_text_embeds
  itself. NOTE(review): the former public name preprocess_text_embeds is
  removed. Verify that no caller outside this file still uses it.
* Change relative to the original commit: the padding zero-out in
  _preprocess_text_embeds is guarded with "target_attention_mask is not
  None", because that parameter defaults to None and None has no .bool().
  The blob id after ".." on the index line is from the original commit.

 library/anima_models.py | 74 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 62 insertions(+), 12 deletions(-)

diff --git a/library/anima_models.py b/library/anima_models.py
index d3adff5f..d7e800c9 100644
--- a/library/anima_models.py
+++ b/library/anima_models.py
@@ -1191,4 +1191,8 @@ class MiniTrainDIT(nn.Module):
     def device(self):
         return next(self.parameters()).device
 
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
     # def set_flash_attn(self, use_flash_attn: bool):
@@ -1517,6 +1521,7 @@ class LLMAdapterTransformerBlock(nn.Module):
         position_embeddings_context=None,
     ):
         if self.has_self_attn:
+            # Self-attention: target_attention_mask is not expected to be all zeros
             normed = self.norm_self_attn(x)
             attn_out = self.self_attn(
                 normed,
@@ -1526,15 +1531,33 @@ class LLMAdapterTransformerBlock(nn.Module):
             )
             x = x + attn_out
 
-        normed = self.norm_cross_attn(x)
-        attn_out = self.cross_attn(
-            normed,
-            mask=source_attention_mask,
-            context=context,
-            position_embeddings=position_embeddings,
-            position_embeddings_context=position_embeddings_context,
-        )
-        x = x + attn_out
+        if source_attention_mask is not None:
+            # Select batch elements where source_attention_mask has at least one True value
+            batch_indices = torch.where(source_attention_mask.any(dim=(1, 2, 3)))[0]
+            # print("Batch indices for cross-attention:", batch_indices)
+            if len(batch_indices) == 0:
+                pass  # No valid batch elements, skip cross-attention
+            else:
+                normed = self.norm_cross_attn(x[batch_indices])
+                attn_out = self.cross_attn(
+                    normed,
+                    mask=None,
+                    context=context[batch_indices],
+                    position_embeddings=position_embeddings,
+                    position_embeddings_context=position_embeddings_context,
+                )
+                x[batch_indices] = x[batch_indices] + attn_out
+        else:
+            # Standard cross-attention without masking
+            normed = self.norm_cross_attn(x)
+            attn_out = self.cross_attn(
+                normed,
+                mask=source_attention_mask,
+                context=context,
+                position_embeddings=position_embeddings,
+                position_embeddings_context=position_embeddings_context,
+            )
+            x = x + attn_out
 
         x = x + self.mlp(self.norm_mlp(x))
         return x
@@ -1619,26 +1642,53 @@ class Anima(nn.Module):
     def dtype(self):
         return self.net.dtype
 
+    def enable_gradient_checkpointing(self, *args, **kwargs):
+        self.net.enable_gradient_checkpointing(*args, **kwargs)
+
+    def disable_gradient_checkpointing(self):
+        self.net.disable_gradient_checkpointing()
+
+    def enable_block_swap(self, *args, **kwargs):
+        self.net.enable_block_swap(*args, **kwargs)
+
+    def move_to_device_except_swap_blocks(self, *args, **kwargs):
+        self.net.move_to_device_except_swap_blocks(*args, **kwargs)
+
+    def prepare_block_swap_before_forward(self, *args, **kwargs):
+        self.net.prepare_block_swap_before_forward(*args, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
         timesteps: torch.Tensor,
-        context: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
         fps: Optional[torch.Tensor] = None,
         padding_mask: Optional[torch.Tensor] = None,
+        target_input_ids: Optional[torch.Tensor] = None,
+        target_attention_mask: Optional[torch.Tensor] = None,
+        source_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
         return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
 
-    def preprocess_text_embeds(
+    def _preprocess_text_embeds(
         self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
     ):
         if target_input_ids is not None:
-            return self.net.llm_adapter(
+            # print(
+            #     f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
+            # )
+            # print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
+            context = self.net.llm_adapter(
                 source_hidden_states,
                 target_input_ids,
                 target_attention_mask=target_attention_mask,
                 source_attention_mask=source_attention_mask,
             )
+            if target_attention_mask is not None:
+                context[~target_attention_mask.bool()] = 0  # zero out padding tokens
+            # print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
+            return context
         else:
             return source_hidden_states