mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 08:21:46 +00:00
Merge 149917b185 into 1a3ec9ea74
This commit is contained in:
@@ -112,6 +112,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from einops import rearrange
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -564,6 +565,7 @@ class Downsample2D(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -739,7 +741,8 @@ class CrossAttention(nn.Module):
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
with sdpa_kernel(kernels):
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from einops import rearrange
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -387,6 +388,7 @@ class Downsample2D(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -545,7 +547,8 @@ class CrossAttention(nn.Module):
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
with sdpa_kernel(kernels):
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user