Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 06:54:17 +00:00)
Merge pull request #22 from rockerBOO/sage_attn
Add Sage Attention for Lumina
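In short: this PR threads a new use_sage_attn flag from the --use_sage_attn CLI option through load_lumina_model and NextDiT down to JointAttention, which then routes attention through SageAttention's sageattn kernel instead of F.scaled_dot_product_attention / flash_attn. A minimal sketch of enabling it from Python follows; the module path and the ckpt_path/dtype parameter names are assumptions inferred from the hunks below, not confirmed by this diff.

# --- illustrative sketch, not part of this commit ---
import torch
from library import lumina_util  # assumed module path for load_lumina_model

model = lumina_util.load_lumina_model(
    ckpt_path="lumina-2.safetensors",  # placeholder checkpoint path
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    use_flash_attn=False,
    use_sage_attn=True,  # the new flag added by this PR
)

On the command line the same switch is exposed as --use_sage_attn (see the add_lumina_train_arguments hunk further down).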
@@ -38,6 +38,11 @@ except:
    # flash_attn may not be available but it is not required
    pass

try:
    from sageattention import sageattn
except:
    pass

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except:
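The new import mirrors the existing flash_attn guard: if the sageattention package is not installed, the name sageattn is simply never bound, and the except NameError branch in sage_attn() further down turns the failure into a RuntimeError with an install hint. A standalone illustration of that pattern (not the library code itself):

# --- illustrative sketch, not part of this commit ---
try:
    from sageattention import sageattn  # optional dependency
except ImportError:
    pass  # leave the name unbound; callers deal with NameError

def attend(q, k, v):
    try:
        return sageattn(q, k, v, tensor_layout="HND", is_causal=False)
    except NameError as e:
        raise RuntimeError("SageAttention is not installed; see https://github.com/thu-ml/SageAttention") from e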
@@ -273,6 +278,7 @@ class JointAttention(nn.Module):
        n_kv_heads: Optional[int],
        qk_norm: bool,
        use_flash_attn=False,
        use_sage_attn=False,
    ):
        """
        Initialize the Attention module.
@@ -312,13 +318,20 @@ class JointAttention(nn.Module):
            self.q_norm = self.k_norm = nn.Identity()

        self.use_flash_attn = use_flash_attn
        self.use_sage_attn = use_sage_attn

        # self.attention_processor = xformers.ops.memory_efficient_attention
        self.attention_processor = F.scaled_dot_product_attention
        if use_sage_attn:
            self.attention_processor = self.sage_attn
        else:
            # self.attention_processor = xformers.ops.memory_efficient_attention
            self.attention_processor = F.scaled_dot_product_attention

    def set_attention_processor(self, attention_processor):
        self.attention_processor = attention_processor

    def get_attention_processor(self):
        return self.attention_processor

    def forward(
        self,
        x: Tensor,
@@ -354,7 +367,15 @@ class JointAttention(nn.Module):

        softmax_scale = math.sqrt(1 / self.head_dim)

        if self.use_flash_attn:
        if self.use_sage_attn:
            # Handle GQA (Grouped Query Attention) if needed
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

            output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale)
        elif self.use_flash_attn:
            output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
        else:
            n_rep = self.n_local_heads // self.n_local_kv_heads
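For reference, the unsqueeze/repeat/flatten sequence above expands the grouped key/value heads so their count matches the query heads before the tensors reach the SageAttention kernel, which (like a plain scaled_dot_product_attention call without GQA support) is fed equal head counts here. A small self-contained shape check; the sizes are arbitrary:

# --- illustrative sketch, not part of this commit ---
import torch

bsz, seqlen, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
n_rep = n_heads // n_kv_heads  # 4 query heads share each KV head

xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
xk_rep = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

assert xk_rep.shape == (bsz, seqlen, n_heads, head_dim)
# each original KV head appears n_rep times in a row along the head axis
assert torch.equal(xk_rep[:, :, 0], xk[:, :, 0])
assert torch.equal(xk_rep[:, :, n_rep - 1], xk[:, :, 0])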
@@ -430,6 +451,63 @@ class JointAttention(nn.Module):
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float):
        try:
            bsz = q.shape[0]
            seqlen = q.shape[1]

            # Transpose tensors to match SageAttention's expected format (HND layout)
            q_transposed = q.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
            k_transposed = k.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
            v_transposed = v.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]

            # Handle masking for SageAttention
            # We need to filter out masked positions - this approach handles variable sequence lengths
            outputs = []
            for b in range(bsz):
                # Find valid token positions from the mask
                valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
                if valid_indices.numel() == 0:
                    # If all tokens are masked, create a zero output
                    batch_output = torch.zeros(
                        seqlen, self.n_local_heads, self.head_dim,
                        device=q.device, dtype=q.dtype
                    )
                else:
                    # Extract only valid tokens for this batch
                    batch_q = q_transposed[b, :, valid_indices, :]
                    batch_k = k_transposed[b, :, valid_indices, :]
                    batch_v = v_transposed[b, :, valid_indices, :]

                    # Run SageAttention on valid tokens only
                    batch_output_valid = sageattn(
                        batch_q.unsqueeze(0),  # Add batch dimension back
                        batch_k.unsqueeze(0),
                        batch_v.unsqueeze(0),
                        tensor_layout="HND",
                        is_causal=False,
                        sm_scale=softmax_scale
                    )

                    # Create output tensor with zeros for masked positions
                    batch_output = torch.zeros(
                        seqlen, self.n_local_heads, self.head_dim,
                        device=q.device, dtype=q.dtype
                    )
                    # Place valid outputs back in the right positions
                    batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)

                outputs.append(batch_output)

            # Stack batch outputs and reshape to expected format
            output = torch.stack(outputs, dim=0)  # [batch, seq_len, heads, head_dim]
        except NameError as e:
            raise RuntimeError(
                f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
            )

        return output

    def flash_attn(
        self,
        q: Tensor,
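The per-sample loop is the main design choice in sage_attn: no attention mask is passed to the kernel, so padded positions are gathered out before the call and zeros are scattered back afterwards. The same gather, attend, scatter flow is shown below with F.scaled_dot_product_attention standing in for sageattn so the snippet runs without the optional dependency; shapes and the mask are illustrative only.

# --- illustrative sketch, not part of this commit ---
import torch
import torch.nn.functional as F

bsz, seqlen, heads, head_dim = 2, 8, 4, 16
q = torch.randn(bsz, seqlen, heads, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
x_mask = torch.tensor([[True] * 8, [True] * 5 + [False] * 3])  # sample 1 has 3 padded tokens

outputs = []
for b in range(bsz):
    valid = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
    # gather valid tokens and move to (heads, valid_len, head_dim), i.e. the HND layout
    qb, kb, vb = (t[b, valid].permute(1, 0, 2) for t in (q, k, v))
    ob = F.scaled_dot_product_attention(qb.unsqueeze(0), kb.unsqueeze(0), vb.unsqueeze(0)).squeeze(0)
    out = torch.zeros(seqlen, heads, head_dim)
    out[valid] = ob.permute(1, 0, 2)  # scatter back; padded rows stay zero
    outputs.append(out)

output = torch.stack(outputs)  # (bsz, seqlen, heads, head_dim), same layout the real method returns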
@@ -573,6 +651,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
        qk_norm: bool,
        modulation=True,
        use_flash_attn=False,
        use_sage_attn=False,
    ) -> None:
        """
        Initialize a TransformerBlock.
@@ -595,7 +674,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
@@ -766,6 +845,7 @@ class NextDiT(nn.Module):
        axes_dims: List[int] = [16, 56, 56],
        axes_lens: List[int] = [1, 512, 512],
        use_flash_attn=False,
        use_sage_attn=False,
    ) -> None:
        """
        Initialize the NextDiT model.
@@ -819,7 +899,6 @@ class NextDiT(nn.Module):
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    use_flash_attn=use_flash_attn,
                )
                for layer_id in range(n_refiner_layers)
            ]
@@ -845,7 +924,6 @@ class NextDiT(nn.Module):
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    use_flash_attn=use_flash_attn,
                )
                for layer_id in range(n_refiner_layers)
            ]
@@ -867,6 +945,7 @@ class NextDiT(nn.Module):
                    norm_eps,
                    qk_norm,
                    use_flash_attn=use_flash_attn,
                    use_sage_attn=use_sage_attn,
                )
                for layer_id in range(n_layers)
            ]

@@ -1083,6 +1083,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
        action="store_true",
        help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
    )
    parser.add_argument(
        "--use_sage_attn",
        action="store_true",
        help="Use Sage Attention for the model / モデルにSage Attentionを使用する",
    )
    parser.add_argument(
        "--system_prompt",
        type=str,

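Both options are plain store_true flags, so they default to False and can be combined on the command line; when both are set, use_sage_attn takes priority inside JointAttention.forward as shown in the earlier hunk. A quick standalone check of the parser behaviour (English help texts only; the actual bilingual strings are as in the hunk above):

# --- illustrative sketch, not part of this commit ---
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_flash_attn", action="store_true", help="Use Flash Attention for the model")
parser.add_argument("--use_sage_attn", action="store_true", help="Use Sage Attention for the model")

args = parser.parse_args(["--use_sage_attn"])
assert args.use_sage_attn is True and args.use_flash_attn is False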
@@ -27,6 +27,7 @@ def load_lumina_model(
    device: torch.device,
    disable_mmap: bool = False,
    use_flash_attn: bool = False,
    use_sage_attn: bool = False,
):
    """
    Load the Lumina model from the checkpoint path.
@@ -43,7 +44,7 @@ def load_lumina_model(
    """
    logger.info("Building Lumina")
    with torch.device("meta"):
        model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype)
        model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)

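A side note on the surrounding context of this hunk: the model is constructed under torch.device("meta"), so NextDiT is built without allocating real weights, and the safetensors state dict supplies the actual parameters afterwards. A toy version of that pattern with a plain nn.Linear (the assign=True materialization step is an assumption about how the meta parameters get filled in; the actual loading code is outside this hunk):

# --- illustrative sketch, not part of this commit ---
import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(4, 4)  # parameters carry only shape/dtype, no storage

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
model.load_state_dict(state_dict, assign=True)  # replace meta tensors with real ones
print(model(torch.ones(1, 4)).shape)  # torch.Size([1, 4])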
@@ -58,6 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
            torch.device("cpu"),
            disable_mmap=args.disable_mmap_load_safetensors,
            use_flash_attn=args.use_flash_attn,
            use_sage_attn=args.use_sage_attn
        )

        if args.fp8_base: