diff --git a/library/lumina_models.py b/library/lumina_models.py index 2d4c6527..e00dcf96 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -38,6 +38,11 @@ except: # flash_attn may not be available but it is not required pass +try: + from sageattention import sageattn +except: + pass + try: from apex.normalization import FusedRMSNorm as RMSNorm except: @@ -273,6 +278,7 @@ class JointAttention(nn.Module): n_kv_heads: Optional[int], qk_norm: bool, use_flash_attn=False, + use_sage_attn=False, ): """ Initialize the Attention module. @@ -312,13 +318,20 @@ class JointAttention(nn.Module): self.q_norm = self.k_norm = nn.Identity() self.use_flash_attn = use_flash_attn + self.use_sage_attn = use_sage_attn - # self.attention_processor = xformers.ops.memory_efficient_attention - self.attention_processor = F.scaled_dot_product_attention + if use_sage_attn: + self.attention_processor = self.sage_attn + else: + # self.attention_processor = xformers.ops.memory_efficient_attention + self.attention_processor = F.scaled_dot_product_attention def set_attention_processor(self, attention_processor): self.attention_processor = attention_processor + def get_attention_processor(self): + return self.attention_processor + def forward( self, x: Tensor, @@ -354,7 +367,15 @@ class JointAttention(nn.Module): softmax_scale = math.sqrt(1 / self.head_dim) - if self.use_flash_attn: + if self.use_sage_attn: + # Handle GQA (Grouped Query Attention) if needed + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep > 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale) + elif self.use_flash_attn: output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) else: n_rep = self.n_local_heads // self.n_local_kv_heads @@ -430,6 +451,63 @@ class JointAttention(nn.Module): (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + def 
sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float): + try: + bsz = q.shape[0] + seqlen = q.shape[1] + + # Transpose tensors to match SageAttention's expected format (HND layout) + q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + + # Handle masking for SageAttention + # We need to filter out masked positions - this approach handles variable sequence lengths + outputs = [] + for b in range(bsz): + # Find valid token positions from the mask + valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1) + if valid_indices.numel() == 0: + # If all tokens are masked, create a zero output + batch_output = torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype + ) + else: + # Extract only valid tokens for this batch + batch_q = q_transposed[b, :, valid_indices, :] + batch_k = k_transposed[b, :, valid_indices, :] + batch_v = v_transposed[b, :, valid_indices, :] + + # Run SageAttention on valid tokens only + batch_output_valid = sageattn( + batch_q.unsqueeze(0), # Add batch dimension back + batch_k.unsqueeze(0), + batch_v.unsqueeze(0), + tensor_layout="HND", + is_causal=False, + sm_scale=softmax_scale + ) + + # Create output tensor with zeros for masked positions + batch_output = torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype + ) + # Place valid outputs back in the right positions + batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2) + + outputs.append(batch_output) + + # Stack batch outputs and reshape to expected format + output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim] + except NameError as e: + raise RuntimeError( + f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. 
/ Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}" + ) + + return output + def flash_attn( self, q: Tensor, @@ -573,6 +651,7 @@ class JointTransformerBlock(GradientCheckpointMixin): qk_norm: bool, modulation=True, use_flash_attn=False, + use_sage_attn=False, ) -> None: """ Initialize a TransformerBlock. @@ -595,7 +674,7 @@ class JointTransformerBlock(GradientCheckpointMixin): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn) self.feed_forward = FeedForward( dim=dim, hidden_dim=4 * dim, @@ -766,6 +845,7 @@ class NextDiT(nn.Module): axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512], use_flash_attn=False, + use_sage_attn=False, ) -> None: """ Initialize the NextDiT model. @@ -819,7 +899,6 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=False, - use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -845,7 +924,6 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, - use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -867,6 +945,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, use_flash_attn=use_flash_attn, + use_sage_attn=use_sage_attn, ) for layer_id in range(n_layers) ] diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 012922ec..f224e86c 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1083,6 +1083,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): action="store_true", help="Use Flash Attention for the model / モデルにFlash Attentionを使用する", ) + parser.add_argument( + "--use_sage_attn", + action="store_true", + help="Use Sage Attention for the model / モデルにSage Attentionを使用する", + ) parser.add_argument( 
"--system_prompt", type=str, diff --git a/library/lumina_util.py index d9c89938..06f089d4 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -27,6 +27,7 @@ def load_lumina_model( device: torch.device, disable_mmap: bool = False, use_flash_attn: bool = False, + use_sage_attn: bool = False, ): """ Load the Lumina model from the checkpoint path. @@ -43,7 +44,7 @@ """ logger.info("Building Lumina") with torch.device("meta"): - model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype) + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) diff --git a/lumina_train_network.py b/lumina_train_network.py index ab811bd5..6b7e7d22 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -58,6 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, use_flash_attn=args.use_flash_attn, + use_sage_attn=args.use_sage_attn, ) if args.fp8_base: