Kohya-ss-sd-scripts/library/attention.py


import torch


def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens: list[int],
    attn_mode: str = "torch",
    drop_rate: float = 0.0,
) -> torch.Tensor:
    """
    Compute scaled dot-product attention with variable sequence lengths.

    Handles batches with different sequence lengths by splitting the batch and
    processing each sequence individually.

    Args:
        q: Query tensor [B, L, H, D].
        k: Key tensor [B, L, H, D].
        v: Value tensor [B, L, H, D].
        seq_lens: Valid sequence length for each batch element.
        attn_mode: Attention implementation ("torch" or "sageattn").
        drop_rate: Attention dropout rate.

    Returns:
        Attention output tensor [B, L, H*D].
    """
    # Determine tensor layout based on the attention implementation
    if attn_mode == "torch" or attn_mode == "sageattn":
        transpose_fn = lambda x: x.transpose(1, 2)  # [B, H, L, D] for SDPA
    else:
        transpose_fn = lambda x: x  # [B, L, H, D] for other implementations

    # Trim each batch element to its valid sequence length and split into per-sample tensors
    q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
    k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
    v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]

    if attn_mode == "torch":
        x = []
        for i in range(len(q)):
            x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
            # Release per-sample inputs as soon as they are consumed to lower peak memory
            q[i] = None
            k[i] = None
            v[i] = None
            x.append(x_i)
        x = torch.cat(x, dim=0)
    else:
        # Currently only PyTorch SDPA is implemented
        raise NotImplementedError(f"attn_mode '{attn_mode}' is not implemented")
    del q, k, v

    x = transpose_fn(x)  # [B, L, H, D]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, L, H*D]
    return x
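
A minimal usage sketch (not part of the original file; the shapes are illustrative assumptions, and equal seq_lens are used so that the torch.cat along the batch dimension is valid):

if __name__ == "__main__":
    B, L, H, D = 2, 16, 8, 64  # hypothetical batch, length, heads, head dim
    q = torch.randn(B, L, H, D)
    k = torch.randn(B, L, H, D)
    v = torch.randn(B, L, H, D)
    out = attention(q, k, v, seq_lens=[L, L], attn_mode="torch")
    print(out.shape)  # torch.Size([2, 16, 512]) -> [B, L, H*D]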