diff --git a/README.md b/README.md index 43edbbed..f4056851 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 2): +Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. + +Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. + + Aug 21, 2024: The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. diff --git a/library/flux_models.py b/library/flux_models.py index 6f28da60..e38119cd 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -708,9 +708,10 @@ class DoubleStreamBlock(nn.Module): # make attention mask if not None attn_mask = None if txt_attention_mask is not None: - attn_mask = txt_attention_mask # b, seq_len + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len attn_mask = torch.cat( - (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 ) # b, seq_len + img_len # broadcast attn_mask to all heads diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index dd962ebf..e9743534 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -248,7 +248,7 @@ def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 scale_down = scale