Compare commits

...

3 Commits

Author SHA1 Message Date
Dave Lage
12ea9b2ec5 Merge 90d14b9eb0 into fa53f71ec0 2026-04-05 00:39:06 +00:00
rockerBOO
90d14b9eb0 Remove priority 2025-04-12 04:09:39 -04:00
rockerBOO
2ab5bc69e6 Add Flash, cuDNN, Efficient attention for Flux 2025-04-11 23:14:41 -04:00

View File

@@ -18,6 +18,7 @@ import torch
 from einops import rearrange
 from torch import Tensor, nn
 from torch.utils.checkpoint import checkpoint
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from library import custom_offloading_utils
@@ -445,11 +446,13 @@ configs = {
 # region math
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
     q, k = apply_rope(q, k, pe)
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+    with sdpa_kernel(kernels):
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x