mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix issue with attention mask not being applied in single blocks
This commit is contained in:
@@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
Aug 24, 2024:
|
||||
Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified.
|
||||
|
||||
Aug 22, 2024 (update 2):
|
||||
Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option.
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_lower.to(self.target_device)
|
||||
return self.flux_lower(img, txt, vec, pe)
|
||||
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
@@ -352,7 +352,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
@@ -752,18 +752,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# def create_custom_forward(func):
|
||||
# def custom_forward(*inputs):
|
||||
# return func(*inputs)
|
||||
# return custom_forward
|
||||
# return torch.utils.checkpoint.checkpoint(
|
||||
# create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
|
||||
# )
|
||||
# else:
|
||||
# return self._forward(img, txt, vec, pe)
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
@@ -809,7 +797,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@@ -817,16 +805,35 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# make attention mask if not None
|
||||
attn_mask = None
|
||||
if txt_attention_mask is not None:
|
||||
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
||||
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
||||
attn_mask = torch.cat(
|
||||
(
|
||||
attn_mask,
|
||||
torch.ones(
|
||||
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
|
||||
),
|
||||
),
|
||||
dim=1,
|
||||
) # b, seq_len + img_len = x_len
|
||||
|
||||
# broadcast attn_mask to all heads
|
||||
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if not self.cpu_offload_checkpointing:
|
||||
return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
|
||||
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
|
||||
|
||||
# cpu offload checkpointing
|
||||
|
||||
@@ -838,19 +845,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
return self._forward(x, vec, pe)
|
||||
|
||||
# def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# def create_custom_forward(func):
|
||||
# def custom_forward(*inputs):
|
||||
# return func(*inputs)
|
||||
# return custom_forward
|
||||
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
|
||||
# else:
|
||||
# return self._forward(x, vec, pe)
|
||||
return self._forward(x, vec, pe, txt_attention_mask)
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
@@ -1053,7 +1052,7 @@ class Flux(nn.Module):
|
||||
|
||||
if not self.single_blocks_to_swap:
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.single_blocks_to_swap):
|
||||
@@ -1075,7 +1074,7 @@ class Flux(nn.Module):
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved single block {block_idx} to cuda.")
|
||||
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if moving:
|
||||
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
@@ -1250,10 +1249,11 @@ class FluxLower(nn.Module):
|
||||
txt: Tensor,
|
||||
vec: Tensor | None = None,
|
||||
pe: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
Reference in New Issue
Block a user