import torch


def attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_lens: list[int], attn_mode: str = "torch", drop_rate: float = 0.0
) -> torch.Tensor:
    """
    Compute scaled dot-product attention with variable sequence lengths.

    Handles batches with different sequence lengths by splitting and
    processing each sequence individually, then zero-padding each result
    back to the full sequence length.

    Args:
        q: Query tensor [B, L, H, D].
        k: Key tensor [B, L, H, D].
        v: Value tensor [B, L, H, D].
        seq_lens: Valid sequence length for each batch element.
        attn_mode: Attention implementation ("torch" or "sageattn";
            currently only "torch" is implemented).
        drop_rate: Attention dropout rate.

    Returns:
        Attention output tensor [B, L, H*D].
    """
    # Determine tensor layout based on attention implementation
    if attn_mode == "torch" or attn_mode == "sageattn":
        transpose_fn = lambda x: x.transpose(1, 2)  # [B, H, L, D] for SDPA
    else:
        transpose_fn = lambda x: x  # [B, L, H, D] for other implementations

    # Remember the full (padded) sequence length so per-sequence outputs
    # can be padded back to a uniform shape before concatenation
    lq = q.shape[1]

    # Process each batch element with its valid sequence length
    q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
    k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
    v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]

    if attn_mode == "torch":
        x = []
        for i in range(len(q)):
            x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
            # Release per-sequence inputs as soon as they are consumed to reduce peak memory
            q[i] = None
            k[i] = None
            v[i] = None
            # Pad the sequence dimension back to the full length so outputs can be
            # concatenated into a single [B, H, L, D] tensor; padded positions are zero-filled
            x_i = torch.nn.functional.pad(x_i, (0, 0, 0, lq - x_i.shape[2]))
            x.append(x_i)
        x = torch.cat(x, dim=0)
        del q, k, v
    else:
        # Currently only PyTorch SDPA is implemented
        raise NotImplementedError(f"attn_mode '{attn_mode}' is not implemented")

    x = transpose_fn(x)  # back to [B, L, H, D]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, L, H*D]
    return x
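

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises the
    # variable-length path with arbitrary, assumed shapes and sequence lengths.
    B, L, H, D = 2, 8, 4, 16
    q = torch.randn(B, L, H, D)
    k = torch.randn(B, L, H, D)
    v = torch.randn(B, L, H, D)
    seq_lens = [8, 5]  # second sequence is only 5 tokens long, padded to L=8

    out = attention(q, k, v, seq_lens, attn_mode="torch")
    print(out.shape)  # torch.Size([2, 8, 64]) -> [B, L, H*D]
    # Positions beyond each valid length are zero-filled by the padding step
    print(out[1, 5:].abs().max())  # tensor(0.)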