Author: Dave Lage
Date:   2026-02-27 11:13:46 +08:00
Committed by: GitHub

2 changed files with 8 additions and 2 deletions


@@ -112,6 +112,7 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 from torch import nn
 from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from einops import rearrange
 from library.utils import setup_logging
@@ -564,6 +565,7 @@ class Downsample2D(nn.Module):
         return hidden_states
 
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 
 class CrossAttention(nn.Module):
     def __init__(
@@ -739,7 +741,8 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
-        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        with sdpa_kernel(kernels):
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         out = rearrange(out, "b h n d -> b n (h d)", h=h)


@@ -29,6 +29,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from einops import rearrange
 from library.utils import setup_logging
@@ -387,6 +388,7 @@ class Downsample2D(nn.Module):
         return hidden_states
 
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 
 class CrossAttention(nn.Module):
     def __init__(
@@ -545,7 +547,8 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
-        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        with sdpa_kernel(kernels):
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
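
For context on the change above: torch.nn.attention.sdpa_kernel is a context manager (available in recent PyTorch releases; SDPBackend.CUDNN_ATTENTION only exists in newer ones) that restricts which backends F.scaled_dot_product_attention may dispatch to. The standalone sketch below reproduces the same pattern outside the patched modules; the tensor shapes, dtype, and CUDA device are illustrative assumptions, not taken from this commit.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Backend allow-list, same as in the diff above.
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

# Illustrative attention inputs: batch 2, 8 heads, 1024 tokens, head dim 64
# (assumed shapes, not from the commit).
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Inside the context manager, SDPA may only use the listed backends;
# PyTorch dispatches to an enabled backend whose constraints the inputs satisfy.
with sdpa_kernel(kernels):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

Since all four backends are listed, this appears to match PyTorch's default set, so the wrapper mainly makes the dispatch explicit (and re-enables any backends disabled globally for this call) rather than changing which kernel runs.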