mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Compare commits
3 Commits
sd3
...
12ea9b2ec5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12ea9b2ec5 | ||
|
|
90d14b9eb0 | ||
|
|
2ab5bc69e6 |
@@ -18,6 +18,7 @@ import torch
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
|
|
||||||
from library import custom_offloading_utils
|
from library import custom_offloading_utils
|
||||||
|
|
||||||
@@ -445,11 +446,13 @@ configs = {
|
|||||||
|
|
||||||
# region math
|
# region math
|
||||||
|
|
||||||
|
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
|
||||||
|
|
||||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
||||||
q, k = apply_rope(q, k, pe)
|
q, k = apply_rope(q, k, pe)
|
||||||
|
|
||||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
with sdpa_kernel(kernels):
|
||||||
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
x = rearrange(x, "B H L D -> B L (H D)")
|
x = rearrange(x, "B H L D -> B L (H D)")
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
Reference in New Issue
Block a user