Fix issue with attention mask not being applied in single blocks

This commit is contained in:
Kohya S
2024-08-24 12:39:54 +09:00
parent 99744af53a
commit 2e89cd2cc6
3 changed files with 36 additions and 33 deletions

View File

@@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 24, 2024:
Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified.
Aug 22, 2024 (update 2):
Fixed a bug where the embedding was zero-padded when the `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option.

View File

@@ -243,7 +243,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
return self.flux_lower(img, txt, vec, pe)
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
@@ -352,7 +352,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

View File

@@ -752,18 +752,6 @@ class DoubleStreamBlock(nn.Module):
else:
return self._forward(img, txt, vec, pe, txt_attention_mask)
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
# if self.training and self.gradient_checkpointing:
# def create_custom_forward(func):
# def custom_forward(*inputs):
# return func(*inputs)
# return custom_forward
# return torch.utils.checkpoint.checkpoint(
# create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
# )
# else:
# return self._forward(img, txt, vec, pe)
class SingleStreamBlock(nn.Module):
"""
@@ -809,7 +797,7 @@ class SingleStreamBlock(nn.Module):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -817,16 +805,35 @@ class SingleStreamBlock(nn.Module):
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# make attention mask if not None
attn_mask = None
if txt_attention_mask is not None:
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
attn_mask = torch.cat(
(
attn_mask,
torch.ones(
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
),
),
dim=1,
) # b, seq_len + img_len = x_len
# broadcast attn_mask to all heads
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
if self.training and self.gradient_checkpointing:
if not self.cpu_offload_checkpointing:
return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
# cpu offload checkpointing
@@ -838,19 +845,11 @@ class SingleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
)
else:
return self._forward(x, vec, pe)
# def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
# if self.training and self.gradient_checkpointing:
# def create_custom_forward(func):
# def custom_forward(*inputs):
# return func(*inputs)
# return custom_forward
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
# else:
# return self._forward(x, vec, pe)
return self._forward(x, vec, pe, txt_attention_mask)
class LastLayer(nn.Module):
@@ -1053,7 +1052,7 @@ class Flux(nn.Module):
if not self.single_blocks_to_swap:
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
for block_idx in range(self.single_blocks_to_swap):
@@ -1075,7 +1074,7 @@ class Flux(nn.Module):
block.to(self.device) # move to cuda
# print(f"Moved single block {block_idx} to cuda.")
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if moving:
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
@@ -1250,10 +1249,11 @@ class FluxLower(nn.Module):
txt: Tensor,
vec: Tensor | None = None,
pe: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)