support sdpa

This commit is contained in:
ykume
2023-06-11 21:26:15 +09:00
parent 4d0c06e397
commit 9e1683cf2b
9 changed files with 177 additions and 84 deletions

View File

@@ -1788,7 +1788,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
return dq, dk, dv, None, None, None, None
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
unet.set_use_memory_efficient_attention(False, True)
@@ -1800,6 +1800,9 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers):
raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
unet.set_use_sdpa(True)
"""
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
@@ -2048,6 +2051,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う（PyTorch 2.0が必要）")
parser.add_argument(
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
)