mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge pull request #22 from rockerBOO/sage_attn
Add Sage Attention for Lumina
This commit is contained in:
@@ -38,6 +38,11 @@ except:
|
|||||||
# flash_attn may not be available but it is not required
|
# flash_attn may not be available but it is not required
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||||
except:
|
except:
|
||||||
@@ -273,6 +278,7 @@ class JointAttention(nn.Module):
|
|||||||
n_kv_heads: Optional[int],
|
n_kv_heads: Optional[int],
|
||||||
qk_norm: bool,
|
qk_norm: bool,
|
||||||
use_flash_attn=False,
|
use_flash_attn=False,
|
||||||
|
use_sage_attn=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Attention module.
|
Initialize the Attention module.
|
||||||
@@ -312,13 +318,20 @@ class JointAttention(nn.Module):
|
|||||||
self.q_norm = self.k_norm = nn.Identity()
|
self.q_norm = self.k_norm = nn.Identity()
|
||||||
|
|
||||||
self.use_flash_attn = use_flash_attn
|
self.use_flash_attn = use_flash_attn
|
||||||
|
self.use_sage_attn = use_sage_attn
|
||||||
|
|
||||||
# self.attention_processor = xformers.ops.memory_efficient_attention
|
if use_sage_attn :
|
||||||
self.attention_processor = F.scaled_dot_product_attention
|
self.attention_processor = self.sage_attn
|
||||||
|
else:
|
||||||
|
# self.attention_processor = xformers.ops.memory_efficient_attention
|
||||||
|
self.attention_processor = F.scaled_dot_product_attention
|
||||||
|
|
||||||
def set_attention_processor(self, attention_processor):
|
def set_attention_processor(self, attention_processor):
|
||||||
self.attention_processor = attention_processor
|
self.attention_processor = attention_processor
|
||||||
|
|
||||||
|
def get_attention_processor(self):
|
||||||
|
return self.attention_processor
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -354,7 +367,15 @@ class JointAttention(nn.Module):
|
|||||||
|
|
||||||
softmax_scale = math.sqrt(1 / self.head_dim)
|
softmax_scale = math.sqrt(1 / self.head_dim)
|
||||||
|
|
||||||
if self.use_flash_attn:
|
if self.use_sage_attn:
|
||||||
|
# Handle GQA (Grouped Query Attention) if needed
|
||||||
|
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
if n_rep >= 1:
|
||||||
|
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
|
||||||
|
output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale)
|
||||||
|
elif self.use_flash_attn:
|
||||||
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
||||||
else:
|
else:
|
||||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
@@ -430,6 +451,63 @@ class JointAttention(nn.Module):
|
|||||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float):
|
||||||
|
try:
|
||||||
|
bsz = q.shape[0]
|
||||||
|
seqlen = q.shape[1]
|
||||||
|
|
||||||
|
# Transpose tensors to match SageAttention's expected format (HND layout)
|
||||||
|
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||||
|
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||||
|
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
# Handle masking for SageAttention
|
||||||
|
# We need to filter out masked positions - this approach handles variable sequence lengths
|
||||||
|
outputs = []
|
||||||
|
for b in range(bsz):
|
||||||
|
# Find valid token positions from the mask
|
||||||
|
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
|
||||||
|
if valid_indices.numel() == 0:
|
||||||
|
# If all tokens are masked, create a zero output
|
||||||
|
batch_output = torch.zeros(
|
||||||
|
seqlen, self.n_local_heads, self.head_dim,
|
||||||
|
device=q.device, dtype=q.dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Extract only valid tokens for this batch
|
||||||
|
batch_q = q_transposed[b, :, valid_indices, :]
|
||||||
|
batch_k = k_transposed[b, :, valid_indices, :]
|
||||||
|
batch_v = v_transposed[b, :, valid_indices, :]
|
||||||
|
|
||||||
|
# Run SageAttention on valid tokens only
|
||||||
|
batch_output_valid = sageattn(
|
||||||
|
batch_q.unsqueeze(0), # Add batch dimension back
|
||||||
|
batch_k.unsqueeze(0),
|
||||||
|
batch_v.unsqueeze(0),
|
||||||
|
tensor_layout="HND",
|
||||||
|
is_causal=False,
|
||||||
|
sm_scale=softmax_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create output tensor with zeros for masked positions
|
||||||
|
batch_output = torch.zeros(
|
||||||
|
seqlen, self.n_local_heads, self.head_dim,
|
||||||
|
device=q.device, dtype=q.dtype
|
||||||
|
)
|
||||||
|
# Place valid outputs back in the right positions
|
||||||
|
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
|
||||||
|
|
||||||
|
outputs.append(batch_output)
|
||||||
|
|
||||||
|
# Stack batch outputs and reshape to expected format
|
||||||
|
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
|
||||||
|
except NameError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
def flash_attn(
|
def flash_attn(
|
||||||
self,
|
self,
|
||||||
q: Tensor,
|
q: Tensor,
|
||||||
@@ -573,6 +651,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
|||||||
qk_norm: bool,
|
qk_norm: bool,
|
||||||
modulation=True,
|
modulation=True,
|
||||||
use_flash_attn=False,
|
use_flash_attn=False,
|
||||||
|
use_sage_attn=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize a TransformerBlock.
|
Initialize a TransformerBlock.
|
||||||
@@ -595,7 +674,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.head_dim = dim // n_heads
|
self.head_dim = dim // n_heads
|
||||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
|
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn)
|
||||||
self.feed_forward = FeedForward(
|
self.feed_forward = FeedForward(
|
||||||
dim=dim,
|
dim=dim,
|
||||||
hidden_dim=4 * dim,
|
hidden_dim=4 * dim,
|
||||||
@@ -766,6 +845,7 @@ class NextDiT(nn.Module):
|
|||||||
axes_dims: List[int] = [16, 56, 56],
|
axes_dims: List[int] = [16, 56, 56],
|
||||||
axes_lens: List[int] = [1, 512, 512],
|
axes_lens: List[int] = [1, 512, 512],
|
||||||
use_flash_attn=False,
|
use_flash_attn=False,
|
||||||
|
use_sage_attn=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the NextDiT model.
|
Initialize the NextDiT model.
|
||||||
@@ -819,7 +899,6 @@ class NextDiT(nn.Module):
|
|||||||
norm_eps,
|
norm_eps,
|
||||||
qk_norm,
|
qk_norm,
|
||||||
modulation=False,
|
modulation=False,
|
||||||
use_flash_attn=use_flash_attn,
|
|
||||||
)
|
)
|
||||||
for layer_id in range(n_refiner_layers)
|
for layer_id in range(n_refiner_layers)
|
||||||
]
|
]
|
||||||
@@ -845,7 +924,6 @@ class NextDiT(nn.Module):
|
|||||||
norm_eps,
|
norm_eps,
|
||||||
qk_norm,
|
qk_norm,
|
||||||
modulation=True,
|
modulation=True,
|
||||||
use_flash_attn=use_flash_attn,
|
|
||||||
)
|
)
|
||||||
for layer_id in range(n_refiner_layers)
|
for layer_id in range(n_refiner_layers)
|
||||||
]
|
]
|
||||||
@@ -867,6 +945,7 @@ class NextDiT(nn.Module):
|
|||||||
norm_eps,
|
norm_eps,
|
||||||
qk_norm,
|
qk_norm,
|
||||||
use_flash_attn=use_flash_attn,
|
use_flash_attn=use_flash_attn,
|
||||||
|
use_sage_attn=use_sage_attn,
|
||||||
)
|
)
|
||||||
for layer_id in range(n_layers)
|
for layer_id in range(n_layers)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1083,6 +1083,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
|
help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_sage_attn",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Sage Attention for the model / モデルにSage Attentionを使用する",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--system_prompt",
|
"--system_prompt",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ def load_lumina_model(
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
disable_mmap: bool = False,
|
disable_mmap: bool = False,
|
||||||
use_flash_attn: bool = False,
|
use_flash_attn: bool = False,
|
||||||
|
use_sage_attn: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load the Lumina model from the checkpoint path.
|
Load the Lumina model from the checkpoint path.
|
||||||
@@ -43,7 +44,7 @@ def load_lumina_model(
|
|||||||
"""
|
"""
|
||||||
logger.info("Building Lumina")
|
logger.info("Building Lumina")
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype)
|
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)
|
||||||
|
|
||||||
logger.info(f"Loading state dict from {ckpt_path}")
|
logger.info(f"Loading state dict from {ckpt_path}")
|
||||||
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
torch.device("cpu"),
|
torch.device("cpu"),
|
||||||
disable_mmap=args.disable_mmap_load_safetensors,
|
disable_mmap=args.disable_mmap_load_safetensors,
|
||||||
use_flash_attn=args.use_flash_attn,
|
use_flash_attn=args.use_flash_attn,
|
||||||
|
use_sage_attn=args.use_sage_attn
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.fp8_base:
|
if args.fp8_base:
|
||||||
|
|||||||
Reference in New Issue
Block a user