mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock
This commit is contained in:
@@ -1191,6 +1191,10 @@ class MiniTrainDIT(nn.Module):
|
||||
def device(self):
    """Return the ``device`` attribute of the first item from ``self.parameters()``.

    Raises:
        StopIteration: if ``self.parameters()`` yields nothing.
    """
    first_param = next(self.parameters())
    return first_param.device
|
||||
|
||||
@property
def dtype(self):
    """Return the ``dtype`` attribute of the first item from ``self.parameters()``.

    Raises:
        StopIteration: if ``self.parameters()`` yields nothing.
    """
    first_param = next(self.parameters())
    return first_param.dtype
|
||||
|
||||
# def set_flash_attn(self, use_flash_attn: bool):
|
||||
# """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
|
||||
|
||||
@@ -1517,6 +1521,7 @@ class LLMAdapterTransformerBlock(nn.Module):
|
||||
position_embeddings_context=None,
|
||||
):
|
||||
if self.has_self_attn:
|
||||
# Self-attention: target_attention_mask is not expected to be all zeros
|
||||
normed = self.norm_self_attn(x)
|
||||
attn_out = self.self_attn(
|
||||
normed,
|
||||
@@ -1526,15 +1531,33 @@ class LLMAdapterTransformerBlock(nn.Module):
|
||||
)
|
||||
x = x + attn_out
|
||||
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=source_attention_mask,
|
||||
context=context,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x = x + attn_out
|
||||
if source_attention_mask is not None:
|
||||
# Select batch elements where source_attention_mask has at least one True value
|
||||
batch_indices = torch.where(source_attention_mask.any(dim=(1, 2, 3)))[0]
|
||||
# print("Batch indices for cross-attention:", batch_indices)
|
||||
if len(batch_indices) == 0:
|
||||
pass # No valid batch elements, skip cross-attention
|
||||
else:
|
||||
normed = self.norm_cross_attn(x[batch_indices])
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=None,
|
||||
context=context[batch_indices],
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x[batch_indices] = x[batch_indices] + attn_out
|
||||
else:
|
||||
# Standard cross-attention without masking
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=source_attention_mask,
|
||||
context=context,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x = x + attn_out
|
||||
|
||||
x = x + self.mlp(self.norm_mlp(x))
|
||||
return x
|
||||
@@ -1619,27 +1642,53 @@ class Anima(nn.Module):
|
||||
def dtype(self):
    """Return the ``dtype`` reported by the wrapped ``self.net``."""
    wrapped = self.net
    return wrapped.dtype
|
||||
|
||||
def enable_gradient_checkpointing(self, *args, **kwargs):
    """Forward the call, arguments unchanged, to ``self.net.enable_gradient_checkpointing``.

    The delegate's return value is discarded; this method returns ``None``.
    """
    delegate = self.net.enable_gradient_checkpointing
    delegate(*args, **kwargs)
|
||||
|
||||
def disable_gradient_checkpointing(self):
    """Forward the call to ``self.net.disable_gradient_checkpointing``.

    The delegate's return value is discarded; this method returns ``None``.
    """
    delegate = self.net.disable_gradient_checkpointing
    delegate()
|
||||
|
||||
def enable_block_swap(self, *args, **kwargs):
    """Forward the call, arguments unchanged, to ``self.net.enable_block_swap``.

    The delegate's return value is discarded; this method returns ``None``.
    """
    delegate = self.net.enable_block_swap
    delegate(*args, **kwargs)
|
||||
|
||||
def move_to_device_except_swap_blocks(self, *args, **kwargs):
    """Forward the call, arguments unchanged, to ``self.net.move_to_device_except_swap_blocks``.

    The delegate's return value is discarded; this method returns ``None``.
    """
    delegate = self.net.move_to_device_except_swap_blocks
    delegate(*args, **kwargs)
|
||||
|
||||
def prepare_block_swap_before_forward(self, *args, **kwargs):
    """Forward the call, arguments unchanged, to ``self.net.prepare_block_swap_before_forward``.

    The delegate's return value is discarded; this method returns ``None``.
    """
    delegate = self.net.prepare_block_swap_before_forward
    delegate(*args, **kwargs)
|
||||
|
||||
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    target_input_ids: Optional[torch.Tensor] = None,
    target_attention_mask: Optional[torch.Tensor] = None,
    source_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Preprocess the text context, then run the wrapped network.

    ``context`` is first passed through ``self._preprocess_text_embeds``
    together with ``target_input_ids`` and both attention masks; the value it
    returns is what ``self.net`` receives as its third positional argument.

    Args:
        x: First positional input handed to ``self.net``.
        timesteps: Second positional input handed to ``self.net``.
        context: Text embeddings to preprocess; declared exactly once here
            (a duplicated ``context`` parameter is a SyntaxError).
        fps: Forwarded to ``self.net`` as the ``fps`` keyword.
        padding_mask: Forwarded to ``self.net`` as the ``padding_mask`` keyword.
        target_input_ids: Passed to ``self._preprocess_text_embeds``.
        target_attention_mask: Passed to ``self._preprocess_text_embeds``.
        source_attention_mask: Passed to ``self._preprocess_text_embeds``.
        **kwargs: Forwarded unchanged to ``self.net``.

    Returns:
        Whatever ``self.net`` returns.
    """
    context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
    return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
|
||||
|
||||
def _preprocess_text_embeds(
    self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
    """Run the text embeddings through ``self.net.llm_adapter`` when target ids are given.

    Args:
        source_hidden_states: Embeddings passed to the adapter as its first
            argument; returned untouched when ``target_input_ids`` is ``None``.
        target_input_ids: Second adapter argument; ``None`` bypasses the adapter.
        target_attention_mask: Optional mask forwarded to the adapter. When
            provided, adapter-output positions where the mask is false are set
            to zero.
        source_attention_mask: Optional mask forwarded to the adapter.

    Returns:
        The adapter output (with masked-out positions zeroed when a target
        mask is supplied), or ``source_hidden_states`` when
        ``target_input_ids`` is ``None``.
    """
    if target_input_ids is None:
        return source_hidden_states

    context = self.net.llm_adapter(
        source_hidden_states,
        target_input_ids,
        target_attention_mask=target_attention_mask,
        source_attention_mask=source_attention_mask,
    )
    # The mask is optional (default None); only zero padding tokens when one
    # was actually supplied, instead of failing on ``None.bool()``.
    if target_attention_mask is not None:
        context[~target_attention_mask.bool()] = 0  # zero out padding tokens
    return context
|
||||
|
||||
|
||||
Reference in New Issue
Block a user