Fix error in applying mask in Attention and add LoRA converter script

This commit is contained in:
Kohya S
2024-08-21 12:30:23 +09:00
parent e17c42cb0d
commit 2b07a92c8d
3 changed files with 10 additions and 3 deletions

View File

@@ -708,9 +708,10 @@ class DoubleStreamBlock(nn.Module):
# make attention mask if not None
attn_mask = None
if txt_attention_mask is not None:
attn_mask = txt_attention_mask # b, seq_len
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
attn_mask = torch.cat(
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
) # b, seq_len + img_len
# broadcast attn_mask to all heads