Add Flash, cuDNN, Efficient attention for Flux

rockerBOO
2025-04-11 23:12:29 -04:00
parent 5a18a03ffc
commit 2ab5bc69e6


@@ -18,6 +18,7 @@ import torch
 from einops import rearrange
 from torch import Tensor, nn
 from torch.utils.checkpoint import checkpoint
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from library import custom_offloading_utils
@@ -445,11 +446,13 @@ configs = {
 # region math
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
     q, k = apply_rope(q, k, pe)
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+    with sdpa_kernel(kernels, set_priority=True):
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x
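
For reference, below is a minimal standalone sketch of the same backend-selection pattern outside the Flux attention function. It is not part of the commit; it assumes a CUDA device and a PyTorch build where torch.nn.attention.sdpa_kernel accepts the set_priority argument. With set_priority=True the listed backends are tried in order and the first one that supports the given inputs is used, with the MATH backend as the always-available fallback.

# Minimal sketch (not from the commit): kernel-priority SDPA in isolation.
# Assumes a CUDA device and a PyTorch version whose sdpa_kernel supports set_priority.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Preferred backends, tried in list order; MATH is the always-available fallback.
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

# Dummy (batch, heads, seq_len, head_dim) tensors in half precision, as the fused kernels expect.
q = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)

with sdpa_kernel(kernels, set_priority=True):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 256, 64])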