From 2ab5bc69e657fa6d242ea8fe83b01c72c9ec0622 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 23:12:29 -0400 Subject: [PATCH 1/2] Add Flash, cuDNN, Efficient attention for Flux --- library/flux_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..1e187288 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -18,6 +18,7 @@ import torch from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint +from torch.nn.attention import SDPBackend, sdpa_kernel from library import custom_offloading_utils @@ -445,11 +446,13 @@ configs = { # region math +kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + with sdpa_kernel(kernels, set_priority=True): + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x From 90d14b9eb00d22ac7790e22818be7e534a2939b1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 12 Apr 2025 04:09:39 -0400 Subject: [PATCH 2/2] Remove priority --- library/flux_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_models.py b/library/flux_models.py index 1e187288..0819bad8 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -451,7 +451,7 @@ kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EF def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - with sdpa_kernel(kernels, set_priority=True): + with sdpa_kernel(kernels): x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)")