Fix samples, LoRA training. Add system prompt, use_flash_attn

rockerBOO
2025-02-23 01:29:18 -05:00
parent 6597631b90
commit 025cca699b
10 changed files with 888 additions and 386 deletions


@@ -14,14 +14,19 @@ from typing import List, Optional, Tuple
from dataclasses import dataclass
from einops import rearrange
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
except:
pass
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except:
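The flash_attn import is now optional: a failed import is swallowed here and only matters if use_flash_attn is actually requested. A minimal sketch of that pattern; FLASH_ATTN_AVAILABLE is a hypothetical flag added purely for illustration, not part of the commit:

try:
    from flash_attn import flash_attn_varlen_func  # noqa: F401
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    # Without flash_attn installed, the SDPA path remains usable.
    FLASH_ATTN_AVAILABLE = False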
@@ -75,7 +80,15 @@ class LuminaParams:
@classmethod
def get_7b_config(cls) -> "LuminaParams":
"""Returns the configuration for the 7B parameter model"""
return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
return cls(
patch_size=2,
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
axes_dims=[64, 64, 64],
axes_lens=[300, 512, 512],
)
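A quick usage sketch of the 7B preset (the import path is assumed; adjust it to wherever this module lives):

from lumina_models import LuminaParams  # hypothetical module path

params = LuminaParams.get_7b_config()
head_dim = params.dim // params.n_heads            # 4096 // 32 = 128
gqa_groups = params.n_heads // params.n_kv_heads   # 32 // 8 = 4 query heads per KV head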
class GradientCheckpointMixin(nn.Module):
@@ -248,6 +261,7 @@ class JointAttention(nn.Module):
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
use_flash_attn=False,
):
"""
Initialize the Attention module.
@@ -286,7 +300,7 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
self.flash_attn = False
self.use_flash_attn = use_flash_attn
# self.attention_processor = xformers.ops.memory_efficient_attention
self.attention_processor = F.scaled_dot_product_attention
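The use_flash_attn flag is stored on the module and the dense path defaults to torch SDPA; set_attention_processor (shown in the following hunk) accepts any callable with the same (q, k, v, attn_mask=..., scale=...) signature. A hedged sketch of swapping in a wrapper, with a hypothetical debug function:

import torch.nn.functional as F

def debug_sdpa(q, k, v, attn_mask=None, scale=None):
    # hypothetical wrapper, only to illustrate the expected call signature
    print("attention shapes:", q.shape, k.shape, v.shape)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)

# attn.set_attention_processor(debug_sdpa)  # attn: an existing JointAttention instance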
@@ -294,35 +308,63 @@ class JointAttention(nn.Module):
def set_attention_processor(self, attention_processor):
self.attention_processor = attention_processor
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
def forward(
self,
x: Tensor,
x_mask: Tensor,
freqs_cis: Tensor,
) -> Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
x:
x_mask:
freqs_cis:
"""
with torch.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in)
bsz, seqlen, _ = x.shape
dtype = x.dtype
xq, xk, xv = torch.split(
self.qkv(x),
[
self.n_local_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
],
dim=-1,
)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = apply_rope(xq, freqs_cis=freqs_cis)
xk = apply_rope(xk, freqs_cis=freqs_cis)
xq, xk = xq.to(dtype), xk.to(dtype)
softmax_scale = math.sqrt(1 / self.head_dim)
if self.use_flash_attn:
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
else:
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = (
self.attention_processor(
xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3),
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
scale=softmax_scale,
)
.permute(0, 2, 1, 3)
.to(dtype)
)
output = output.flatten(-2)
return self.out(output)
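On the SDPA path above, each key/value head is repeated n_rep times so the n_kv_heads layout matches the n_heads queries (grouped-query attention). A standalone toy check of that reshape, with arbitrary sizes:

import torch

bsz, seqlen, n_kv_heads, head_dim, n_rep = 2, 16, 8, 64, 4
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
xk_rep = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
# each KV head is duplicated n_rep times to match the query head count
assert xk_rep.shape == (bsz, seqlen, n_kv_heads * n_rep, head_dim)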
# copied from huggingface modeling_llama.py
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
@@ -377,46 +419,17 @@ class JointAttention(nn.Module):
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def forward(
def flash_attn(
self,
x: Tensor,
q: Tensor,
k: Tensor,
v: Tensor,
x_mask: Tensor,
freqs_cis: Tensor,
softmax_scale,
) -> Tensor:
"""
bsz, seqlen, _, _ = q.shape
Args:
x:
x_mask:
freqs_cis:
Returns:
"""
bsz, seqlen, _ = x.shape
dtype = x.dtype
xq, xk, xv = torch.split(
self.qkv(x),
[
self.n_local_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
],
dim=-1,
)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
xq, xk = xq.to(dtype), xk.to(dtype)
softmax_scale = math.sqrt(1 / self.head_dim)
if self.flash_attn:
try:
# begin var_len flash attn
(
query_states,
@@ -425,7 +438,7 @@ class JointAttention(nn.Module):
indices_q,
cu_seq_lens,
max_seq_lens,
) = self._upad_input(xq, xk, xv, x_mask, seqlen)
) = self._upad_input(q, k, v, x_mask, seqlen)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -445,27 +458,12 @@ class JointAttention(nn.Module):
output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
# end var_len_flash_attn
else:
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = (
self.attention_processor(
xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3),
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
scale=softmax_scale,
)
.permute(0, 2, 1, 3)
.to(dtype)
return output
except NameError as e:
raise RuntimeError(
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)
output = output.flatten(-2)
return self.out(output)
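The flash path relies on _upad_input (copied from huggingface modeling_llama.py) to pack the padded batch into a ragged layout for flash_attn_varlen_func. A small self-contained sketch of the cu_seqlens bookkeeping it produces, using a toy mask:

import torch
import torch.nn.functional as F

x_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])                    # toy padding mask
seqlens_in_batch = x_mask.sum(dim=-1, dtype=torch.int32)               # tensor([3, 2])
indices = torch.nonzero(x_mask.flatten(), as_tuple=False).flatten()    # positions of kept tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # [0, 3, 5]
max_seqlen_in_batch = int(seqlens_in_batch.max())                      # 3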
def apply_rope(
x_in: torch.Tensor,
@@ -563,6 +561,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
norm_eps: float,
qk_norm: bool,
modulation=True,
use_flash_attn=False,
) -> None:
"""
Initialize a TransformerBlock.
@@ -585,7 +584,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
@@ -711,7 +710,12 @@ class FinalLayer(GradientCheckpointMixin):
class RopeEmbedder:
def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
def __init__(
self,
theta: float = 10000.0,
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
):
super().__init__()
self.theta = theta
self.axes_dims = axes_dims
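Each positional axis gets its own rotary table: axes_dims[i] channels precomputed for up to axes_lens[i] positions. A hedged sketch of how such per-axis complex frequencies are typically built; only the constructor arguments come from the code above, the internals are assumed:

import torch

theta, axes_dims, axes_lens = 10000.0, [16, 56, 56], [1, 512, 512]
freqs_cis = []
for dim, length in zip(axes_dims, axes_lens):
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))   # (dim // 2,)
    angles = torch.outer(torch.arange(length).float(), inv_freq)          # (length, dim // 2)
    freqs_cis.append(torch.polar(torch.ones_like(angles), angles))        # complex64 table per axis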
@@ -750,6 +754,7 @@ class NextDiT(nn.Module):
cap_feat_dim: int = 5120,
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
use_flash_attn=False,
) -> None:
"""
Initialize the NextDiT model.
@@ -803,6 +808,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=False,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_refiner_layers)
]
@@ -828,6 +834,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_refiner_layers)
]
@@ -848,6 +855,7 @@ class NextDiT(nn.Module):
ffn_dim_multiplier,
norm_eps,
qk_norm,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_layers)
]
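use_flash_attn is now threaded from NextDiT through both refiner block lists and the main layer stack down to every JointAttention. A hedged construction sketch, assuming the module path and that all remaining constructor arguments have defaults:

from lumina_models import NextDiT  # hypothetical module path

model = NextDiT(use_flash_attn=True)   # every attention layer takes the flash_attn path
model_sdpa = NextDiT()                 # default: PyTorch scaled_dot_product_attention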
@@ -988,8 +996,20 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(position_ids)
# Create separate rotary embeddings for captions and images
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
cap_freqs_cis = torch.zeros(
bsz,
encoder_seq_len,
freqs_cis.shape[-1],
device=device,
dtype=freqs_cis.dtype,
)
img_freqs_cis = torch.zeros(
bsz,
image_seq_len,
freqs_cis.shape[-1],
device=device,
dtype=freqs_cis.dtype,
)
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
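The joint rope table places caption positions first, then image positions, so each sample's frequencies are copied into two zero-padded buffers. A toy sketch of that split; the image-side copy is an assumption mirroring the caption line shown above:

import torch

bsz, encoder_seq_len, image_seq_len, rope_dim = 2, 8, 12, 4
l_effective_cap_len = [5, 8]                                   # per-sample caption lengths
seq_lengths = [cap + image_seq_len for cap in l_effective_cap_len]
freqs_cis = torch.randn(bsz, encoder_seq_len + image_seq_len, rope_dim)

cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, rope_dim)
img_freqs_cis = torch.zeros(bsz, image_seq_len, rope_dim)
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
    cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]                       # caption slots
    img_freqs_cis[i, : seq_len - cap_len] = freqs_cis[i, cap_len:seq_len]     # image slots (assumed)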