From 149917b1857fa007b6dba917d337363ac5db526b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 22 Apr 2025 13:51:22 -0400 Subject: [PATCH] Add SDPBackend for SD and SDXL --- library/original_unet.py | 5 ++++- library/sdxl_original_unet.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index e944ff22..bcde6a86 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -112,6 +112,7 @@ from typing import Dict, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel from einops import rearrange from library.utils import setup_logging setup_logging() @@ -560,6 +561,7 @@ class Downsample2D(nn.Module): return hidden_states +kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] class CrossAttention(nn.Module): def __init__( @@ -741,7 +743,8 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + with sdpa_kernel(kernels): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = rearrange(out, "b h n d -> b n (h d)", h=h) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a8..d6d6a2b3 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -29,6 +29,7 @@ import torch import torch.utils.checkpoint from torch import nn from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel from einops import rearrange from .utils import setup_logging @@ -387,6 +388,7 @@ class Downsample2D(nn.Module): return hidden_states +kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, 
SDPBackend.MATH] class CrossAttention(nn.Module): def __init__( @@ -545,7 +547,8 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + with sdpa_kernel(kernels): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = rearrange(out, "b h n d -> b n (h d)", h=h)