Merge pull request #22 from rockerBOO/sage_attn

Add Sage Attention for Lumina
This commit is contained in:
青龍聖者@bdsqlsz
2025-03-03 10:26:02 +08:00
committed by GitHub
4 changed files with 93 additions and 7 deletions

View File

@@ -38,6 +38,11 @@ except:
# flash_attn may not be available but it is not required
pass
try:
from sageattention import sageattn
except:
pass
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except:
@@ -273,6 +278,7 @@ class JointAttention(nn.Module):
n_kv_heads: Optional[int],
qk_norm: bool,
use_flash_attn=False,
use_sage_attn=False,
):
"""
Initialize the Attention module.
@@ -312,13 +318,20 @@ class JointAttention(nn.Module):
self.q_norm = self.k_norm = nn.Identity()
self.use_flash_attn = use_flash_attn
self.use_sage_attn = use_sage_attn
# self.attention_processor = xformers.ops.memory_efficient_attention
self.attention_processor = F.scaled_dot_product_attention
if use_sage_attn :
self.attention_processor = self.sage_attn
else:
# self.attention_processor = xformers.ops.memory_efficient_attention
self.attention_processor = F.scaled_dot_product_attention
def set_attention_processor(self, attention_processor):
self.attention_processor = attention_processor
def get_attention_processor(self):
return self.attention_processor
def forward(
self,
x: Tensor,
@@ -354,7 +367,15 @@ class JointAttention(nn.Module):
softmax_scale = math.sqrt(1 / self.head_dim)
if self.use_flash_attn:
if self.use_sage_attn:
# Handle GQA (Grouped Query Attention) if needed
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale)
elif self.use_flash_attn:
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
else:
n_rep = self.n_local_heads // self.n_local_kv_heads
@@ -430,6 +451,63 @@ class JointAttention(nn.Module):
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float):
try:
bsz = q.shape[0]
seqlen = q.shape[1]
# Transpose tensors to match SageAttention's expected format (HND layout)
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
# Handle masking for SageAttention
# We need to filter out masked positions - this approach handles variable sequence lengths
outputs = []
for b in range(bsz):
# Find valid token positions from the mask
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
if valid_indices.numel() == 0:
# If all tokens are masked, create a zero output
batch_output = torch.zeros(
seqlen, self.n_local_heads, self.head_dim,
device=q.device, dtype=q.dtype
)
else:
# Extract only valid tokens for this batch
batch_q = q_transposed[b, :, valid_indices, :]
batch_k = k_transposed[b, :, valid_indices, :]
batch_v = v_transposed[b, :, valid_indices, :]
# Run SageAttention on valid tokens only
batch_output_valid = sageattn(
batch_q.unsqueeze(0), # Add batch dimension back
batch_k.unsqueeze(0),
batch_v.unsqueeze(0),
tensor_layout="HND",
is_causal=False,
sm_scale=softmax_scale
)
# Create output tensor with zeros for masked positions
batch_output = torch.zeros(
seqlen, self.n_local_heads, self.head_dim,
device=q.device, dtype=q.dtype
)
# Place valid outputs back in the right positions
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
outputs.append(batch_output)
# Stack batch outputs and reshape to expected format
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
except NameError as e:
raise RuntimeError(
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
)
return output
def flash_attn(
self,
q: Tensor,
@@ -573,6 +651,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
qk_norm: bool,
modulation=True,
use_flash_attn=False,
use_sage_attn=False,
) -> None:
"""
Initialize a TransformerBlock.
@@ -595,7 +674,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
@@ -766,6 +845,7 @@ class NextDiT(nn.Module):
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
use_flash_attn=False,
use_sage_attn=False,
) -> None:
"""
Initialize the NextDiT model.
@@ -819,7 +899,6 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=False,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_refiner_layers)
]
@@ -845,7 +924,6 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_refiner_layers)
]
@@ -867,6 +945,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
use_flash_attn=use_flash_attn,
use_sage_attn=use_sage_attn,
)
for layer_id in range(n_layers)
]

View File

@@ -1083,6 +1083,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
)
parser.add_argument(
"--use_sage_attn",
action="store_true",
help="Use Sage Attention for the model / モデルにSage Attentionを使用する",
)
parser.add_argument(
"--system_prompt",
type=str,

View File

@@ -27,6 +27,7 @@ def load_lumina_model(
device: torch.device,
disable_mmap: bool = False,
use_flash_attn: bool = False,
use_sage_attn: bool = False,
):
"""
Load the Lumina model from the checkpoint path.
@@ -43,7 +44,7 @@ def load_lumina_model(
"""
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype)
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)

View File

@@ -58,6 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
use_sage_attn=args.use_sage_attn
)
if args.fp8_base: