Fix error in applying mask in Attention and add LoRA converter script

Kohya S
2024-08-21 12:30:23 +09:00
parent e17c42cb0d
commit 2b07a92c8d
3 changed files with 10 additions and 3 deletions


@@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 21, 2024 (update 2):
Fixed an error in applying the mask in Attention. The attention mask was passed as a float tensor, but it must be bool.
Added a script `convert_flux_lora.py` to convert LoRA between the sd-scripts format (BFL-based) and the AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has large modules, so converting it to the Diffusers format may reduce temporary memory usage in the inference environment. Note that converting it back will increase the size of the LoRA again.
Aug 21, 2024:
The specification of `--apply_t5_attn_mask` has been changed. Previously the T5 output was zero-padded; now two steps are taken: "1. apply the mask when encoding with T5" and "2. apply the mask in the attention of the Double Blocks". Fine-tuning, LoRA training, and inference in `flux_minimal_inference.py` have been updated accordingly.
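
As background for the first fix (an illustration, not part of the commit itself): `torch.nn.functional.scaled_dot_product_attention` treats a boolean `attn_mask` as "True = may attend", whereas a floating-point mask is simply added to the attention logits, so a 0/1 float padding mask does not actually mask anything. A minimal, self-contained sketch of the difference:

```python
import torch
import torch.nn.functional as F

# toy shapes: batch=1, head=1, 4 query tokens, 4 key tokens, head dim=8
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# padding mask over the keys: the last two tokens are padding
keep = torch.tensor([True, True, False, False])

# bool mask: False positions are excluded from attention entirely
bool_mask = keep[None, None, None, :]  # broadcastable to (1, 1, 4, 4)
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# 0/1 float mask: interpreted as an additive bias, so padding is NOT masked out
float_mask = keep.float()[None, None, None, :]
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

print(torch.allclose(out_bool, out_float))  # almost certainly False: the two behave differently
```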


@@ -708,9 +708,10 @@ class DoubleStreamBlock(nn.Module):
 # make attention mask if not None
 attn_mask = None
 if txt_attention_mask is not None:
-    attn_mask = txt_attention_mask  # b, seq_len
+    # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
+    attn_mask = txt_attention_mask.to(torch.bool)  # b, seq_len
     attn_mask = torch.cat(
-        (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1
+        (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
     )  # b, seq_len + img_len
     # broadcast attn_mask to all heads
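
For illustration, a standalone sketch of what the patched block now builds (the sizes below are made up, and the exact broadcast shape used in the model code may differ): the boolean T5 padding mask is extended with all-True entries for the image tokens, then broadcast so it can be consumed by `F.scaled_dot_product_attention`.

```python
import torch

# hypothetical sizes, for illustration only
b, txt_len, img_len = 2, 6, 16

# 0/1 padding mask from the T5 tokenizer (1 = real token, 0 = padding)
txt_attention_mask = torch.ones(b, txt_len)
txt_attention_mask[:, 4:] = 0  # pretend the last two text tokens are padding

# same construction as the patched block: cast to bool, then append
# all-True entries for the image tokens, which are never padded
attn_mask = txt_attention_mask.to(torch.bool)  # (b, txt_len)
attn_mask = torch.cat(
    (attn_mask, torch.ones(b, img_len, device=attn_mask.device, dtype=torch.bool)), dim=1
)  # (b, txt_len + img_len)

# broadcast over heads and query positions: a (b, 1, 1, L) bool mask is
# broadcastable inside F.scaled_dot_product_attention
attn_mask = attn_mask[:, None, None, :]
print(attn_mask.shape)  # torch.Size([2, 1, 1, 22])
```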


@@ -248,7 +248,7 @@ def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
 rank = down_weight.shape[0]
 alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
 scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
-print(f"rank: {rank}, alpha: {alpha}, scale: {scale}")
+# print(f"rank: {rank}, alpha: {alpha}, scale: {scale}")
 # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
 scale_down = scale
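
The hunk is truncated here, but the surrounding comments describe the idea: the forward-pass multiplier `scale = alpha / rank` is baked into the converted weights by splitting it into two factors whose product equals `scale` (e.g. 4 → 2 × 2), one applied to the down weight and one to the up weight. Below is a rough sketch of that idea only; `fold_scale_into_weights` is a hypothetical helper, not the script's actual code, and the script may choose the two factors differently:

```python
import math
import torch

def fold_scale_into_weights(down_weight: torch.Tensor, up_weight: torch.Tensor, alpha: float):
    """Illustrative only: bake `alpha / rank` into the LoRA weights.

    scale_down * scale_up == alpha / rank, so
    (scale_up * up) @ (scale_down * down) == (alpha / rank) * (up @ down).
    """
    rank = down_weight.shape[0]     # down weight is (rank, in_features)
    scale = alpha / rank            # applied as a multiplier in the forward pass
    scale_down = math.sqrt(scale)   # e.g. scale 4 -> 2 and 2
    scale_up = scale / scale_down
    return down_weight * scale_down, up_weight * scale_up
```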