Fix issue with attention mask not being applied in single blocks

Kohya S
2024-08-24 12:39:54 +09:00
parent 99744af53a
commit 2e89cd2cc6
3 changed files with 36 additions and 33 deletions
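The change makes SingleStreamBlock honor txt_attention_mask: the text mask is converted to bool, extended with ones over the image tokens so all image positions stay visible, broadcast across heads, and passed into the attention call. Below is a minimal standalone sketch of that masking pattern; the shapes and tensor names are illustrative only, and it uses plain F.scaled_dot_product_attention rather than the repository's attention() helper.

import torch
import torch.nn.functional as F

# Illustrative sizes only: batch 2, 8 heads, head dim 64,
# 16 text tokens (some of them padding) followed by 48 image tokens.
b, h, d = 2, 8, 64
txt_len, img_len = 16, 48
x_len = txt_len + img_len

q = torch.randn(b, h, x_len, d)
k = torch.randn(b, h, x_len, d)
v = torch.randn(b, h, x_len, d)

# True for real text tokens, False for padding (as a tokenizer would produce).
txt_attention_mask = torch.zeros(b, txt_len, dtype=torch.bool)
txt_attention_mask[:, :10] = True

# Extend the text mask so every image token remains attendable,
# then broadcast it over heads and query positions.
attn_mask = torch.cat(
    (txt_attention_mask, torch.ones(b, img_len, dtype=torch.bool)), dim=1
)  # (b, x_len)
attn_mask = attn_mask[:, None, None, :].expand(-1, h, x_len, -1)

# For a boolean mask, True means "attend" and False means "ignore",
# so padded text positions are excluded as keys for every query.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 64, 64])

The hunks below do the same thing inside SingleStreamBlock._forward before calling attention().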

View File

@@ -752,18 +752,6 @@ class DoubleStreamBlock(nn.Module):
         else:
             return self._forward(img, txt, vec, pe, txt_attention_mask)
 
-    # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
-    #     if self.training and self.gradient_checkpointing:
-    #         def create_custom_forward(func):
-    #             def custom_forward(*inputs):
-    #                 return func(*inputs)
-    #             return custom_forward
-    #         return torch.utils.checkpoint.checkpoint(
-    #             create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
-    #         )
-    #     else:
-    #         return self._forward(img, txt, vec, pe)
 
 class SingleStreamBlock(nn.Module):
     """
@@ -809,7 +797,7 @@ class SingleStreamBlock(nn.Module):
         self.gradient_checkpointing = False
         self.cpu_offload_checkpointing = False
 
-    def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+    def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
         mod, _ = self.modulation(vec)
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -817,16 +805,35 @@ class SingleStreamBlock(nn.Module):
         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         q, k = self.norm(q, k, v)
 
+        # make attention mask if not None
+        attn_mask = None
+        if txt_attention_mask is not None:
+            # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
+            attn_mask = txt_attention_mask.to(torch.bool)  # b, seq_len
+            attn_mask = torch.cat(
+                (
+                    attn_mask,
+                    torch.ones(
+                        attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
+                    ),
+                ),
+                dim=1,
+            )  # b, seq_len + img_len = x_len
+
+            # broadcast attn_mask to all heads
+            attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
+
         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
 
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
         if self.training and self.gradient_checkpointing:
             if not self.cpu_offload_checkpointing:
-                return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
+                return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
 
             # cpu offload checkpointing
@@ -838,19 +845,11 @@ class SingleStreamBlock(nn.Module):
                 return custom_forward
 
-            return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
+            return torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
+            )
         else:
-            return self._forward(x, vec, pe)
-
-    # def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
-    #     if self.training and self.gradient_checkpointing:
-    #         def create_custom_forward(func):
-    #             def custom_forward(*inputs):
-    #                 return func(*inputs)
-    #             return custom_forward
-    #         return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
-    #     else:
-    #         return self._forward(x, vec, pe)
+            return self._forward(x, vec, pe, txt_attention_mask)
 
 class LastLayer(nn.Module):
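Since the single blocks can run under gradient checkpointing (both the plain and the CPU-offload path), the new argument also has to travel through torch.utils.checkpoint.checkpoint, which is why the hunks above append txt_attention_mask to each checkpointed call. A self-contained sketch of that call pattern follows; the _forward body here is a stand-in, not the block's real computation.

import torch
from torch.utils.checkpoint import checkpoint

def _forward(x, vec, pe, txt_attention_mask=None):
    # Placeholder computation; only the argument plumbing matters here.
    out = x * vec.unsqueeze(1) + pe
    if txt_attention_mask is not None:
        out = out * txt_attention_mask[..., None]
    return out

x = torch.randn(2, 6, 4, requires_grad=True)
vec = torch.randn(2, 4)
pe = torch.randn(2, 6, 4)
mask = torch.ones(2, 6)

# The optional mask (or None) is simply appended as another positional input;
# checkpoint recomputes _forward during backward with the same arguments.
out = checkpoint(_forward, x, vec, pe, mask, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 6, 4])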
@@ -1053,7 +1052,7 @@ class Flux(nn.Module):
         if not self.single_blocks_to_swap:
             for block in self.single_blocks:
-                img = block(img, vec=vec, pe=pe)
+                img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
         else:
             # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
             for block_idx in range(self.single_blocks_to_swap):
@@ -1075,7 +1074,7 @@ class Flux(nn.Module):
                     block.to(self.device)  # move to cuda
                     # print(f"Moved single block {block_idx} to cuda.")
 
-                img = block(img, vec=vec, pe=pe)
+                img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
 
                 if moving:
                     self.single_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
@@ -1250,10 +1249,11 @@ class FluxLower(nn.Module):
         txt: Tensor,
         vec: Tensor | None = None,
         pe: Tensor | None = None,
+        txt_attention_mask: Tensor | None = None,
     ) -> Tensor:
         img = torch.cat((txt, img), 1)
         for block in self.single_blocks:
-            img = block(img, vec=vec, pe=pe)
+            img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
         img = img[:, txt.shape[1] :, ...]
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
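The remaining hunks are plumbing: Flux and FluxLower now pass txt_attention_mask to every single block they call, with a default of None so existing callers keep working. A toy illustration of that propagation pattern, with module and tensor names made up for the example:

import torch
from torch import nn, Tensor
from typing import Optional

class ToySingleBlock(nn.Module):
    # Stand-in for SingleStreamBlock; a real block would build an attention
    # mask from txt_attention_mask, here only the argument threading is shown.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor,
                txt_attention_mask: Optional[Tensor] = None) -> Tensor:
        return self.proj(x + pe) + vec.unsqueeze(1)

single_blocks = nn.ModuleList(ToySingleBlock(8) for _ in range(3))

img = torch.randn(2, 6, 8)   # 2 text tokens + 4 image tokens, already concatenated
vec = torch.randn(2, 8)
pe = torch.randn(2, 6, 8)
txt_attention_mask = torch.ones(2, 2, dtype=torch.bool)  # covers the text tokens only

# Every call site forwards the same mask; passing None restores the old unmasked behavior.
for block in single_blocks:
    img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)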