Remove priority

This commit is contained in:
rockerBOO
2025-04-12 04:09:39 -04:00
parent 2ab5bc69e6
commit 90d14b9eb0

View File

@@ -451,7 +451,7 @@ kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EF
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
q, k = apply_rope(q, k, pe)
-with sdpa_kernel(kernels, set_priority=True):
+with sdpa_kernel(kernels):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)")