diff --git a/library/sd3_models.py b/library/sd3_models.py index 840f9186..60356e82 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -645,7 +645,7 @@ class MMDiTBlock(nn.Module): if self.x_block.x_block_self_attn: x_q2, x_k2, x_v2 = x_qkv2 - attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads) + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) else: x = self.x_block.post_attention(x_attn_out, *x_intermediates)