diff --git a/library/flux_models.py b/library/flux_models.py index 1e187288..0819bad8 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -451,7 +451,7 @@ kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EF def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - with sdpa_kernel(kernels, set_priority=True): + with sdpa_kernel(kernels): x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)")