This commit is contained in:
Dave Lage
2026-03-31 04:54:17 +00:00
committed by GitHub

View File

@@ -18,6 +18,7 @@ import torch
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from torch.nn.attention import SDPBackend, sdpa_kernel
from library import custom_offloading_utils from library import custom_offloading_utils
@@ -445,11 +446,13 @@ configs = {
# region math # region math
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
    """Scaled-dot-product attention with rotary position embeddings.

    Applies `apply_rope` to the query/key tensors using positional
    embedding `pe`, then runs SDPA restricted to the backends listed in
    the module-level `kernels` priority list, and finally merges the
    head dimension back into the feature axis.

    Args:
        q, k, v: attention inputs, laid out as B H L D (batch, heads,
            sequence length, head dim) — implied by the rearrange below.
        pe: rotary positional embedding consumed by `apply_rope`.
        attn_mask: optional mask forwarded to
            ``scaled_dot_product_attention``.

    Returns:
        Tensor of shape B L (H*D).
    """
    # Rotate q/k by the positional embedding before attending.
    q, k = apply_rope(q, k, pe)

    # Constrain backend selection to the configured priority list.
    with sdpa_kernel(kernels):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    # Fold the head axis back into the channel axis: B H L D -> B L (H D).
    return rearrange(out, "B H L D -> B L (H D)")