mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix issue with attention mask not being applied in single blocks
This commit is contained in:
@@ -752,18 +752,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# def create_custom_forward(func):
|
||||
# def custom_forward(*inputs):
|
||||
# return func(*inputs)
|
||||
# return custom_forward
|
||||
# return torch.utils.checkpoint.checkpoint(
|
||||
# create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
|
||||
# )
|
||||
# else:
|
||||
# return self._forward(img, txt, vec, pe)
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
@@ -809,7 +797,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@@ -817,16 +805,35 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# make attention mask if not None
|
||||
attn_mask = None
|
||||
if txt_attention_mask is not None:
|
||||
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
||||
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
||||
attn_mask = torch.cat(
|
||||
(
|
||||
attn_mask,
|
||||
torch.ones(
|
||||
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
|
||||
),
|
||||
),
|
||||
dim=1,
|
||||
) # b, seq_len + img_len = x_len
|
||||
|
||||
# broadcast attn_mask to all heads
|
||||
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if not self.cpu_offload_checkpointing:
|
||||
return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
|
||||
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
|
||||
|
||||
# cpu offload checkpointing
|
||||
|
||||
@@ -838,19 +845,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
return self._forward(x, vec, pe)
|
||||
|
||||
# def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# def create_custom_forward(func):
|
||||
# def custom_forward(*inputs):
|
||||
# return func(*inputs)
|
||||
# return custom_forward
|
||||
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
|
||||
# else:
|
||||
# return self._forward(x, vec, pe)
|
||||
return self._forward(x, vec, pe, txt_attention_mask)
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
@@ -1053,7 +1052,7 @@ class Flux(nn.Module):
|
||||
|
||||
if not self.single_blocks_to_swap:
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.single_blocks_to_swap):
|
||||
@@ -1075,7 +1074,7 @@ class Flux(nn.Module):
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved single block {block_idx} to cuda.")
|
||||
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if moving:
|
||||
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
@@ -1250,10 +1249,11 @@ class FluxLower(nn.Module):
|
||||
txt: Tensor,
|
||||
vec: Tensor | None = None,
|
||||
pe: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
Reference in New Issue
Block a user