diff --git a/library/original_unet.py b/library/original_unet.py
index aa9dc233..28863ac1 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -112,6 +112,7 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from einops import rearrange
 
 from library.utils import setup_logging
@@ -564,6 +565,7 @@ class Downsample2D(nn.Module):
 
         return hidden_states
 
 
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 class CrossAttention(nn.Module):
     def __init__(
@@ -739,7 +741,8 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
 
-        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        with sdpa_kernel(kernels):
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
 
diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py
index 0aa07d0d..5d5c2a7c 100644
--- a/library/sdxl_original_unet.py
+++ b/library/sdxl_original_unet.py
@@ -29,6 +29,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from einops import rearrange
 
 from library.utils import setup_logging
@@ -387,6 +388,7 @@ class Downsample2D(nn.Module):
 
         return hidden_states
 
 
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 class CrossAttention(nn.Module):
     def __init__(
@@ -545,7 +547,8 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
 
-        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        with sdpa_kernel(kernels):
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
 
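Note (not part of the patch): a minimal standalone sketch of the sdpa_kernel pattern these hunks introduce, assuming PyTorch >= 2.4 (torch.nn.attention.sdpa_kernel and SDPBackend.CUDNN_ATTENTION do not exist in earlier releases). Tensor shapes and the final print are illustrative only.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Enable every fused backend plus the math fallback; inside the context,
# PyTorch dispatches to whichever enabled backend supports the given inputs
# (dtype, head dim, mask, device), falling back to MATH if none do.
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, heads, tokens, head_dim) -- the "b h n d" layout produced by the
# rearrange calls in the hunks above.
q = torch.randn(2, 8, 64, 40, device=device, dtype=dtype)
k = torch.randn(2, 8, 77, 40, device=device, dtype=dtype)
v = torch.randn(2, 8, 77, 40, device=device, dtype=dtype)

with sdpa_kernel(kernels):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

print(out.shape)  # torch.Size([2, 8, 64, 40])

The context manager restores the previously enabled backend set on exit, so the restriction is scoped to the attention call. Passing a list only whitelists backends; PyTorch's own dispatch logic still chooses among the enabled ones.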