Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-18 01:30:02 +00:00
Fix samples, LoRA training. Add system prompt, use_flash_attn
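The commit threads a new use_flash_attn flag from NextDiT down into every JointAttention and falls back to PyTorch SDPA when flash attention is unavailable. A minimal sketch of how the flag might be passed at construction time; the import path "library.lumina_models" is an assumption, not shown in this diff:

    from library import lumina_models  # assumed module path

    model = lumina_models.NextDiT(
        axes_dims=[64, 64, 64],
        axes_lens=[300, 512, 512],
        use_flash_attn=True,  # True -> flash_attn_varlen_func, False -> F.scaled_dot_product_attention
    )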
@@ -14,14 +14,19 @@ from typing import List, Optional, Tuple
 from dataclasses import dataclass

 from einops import rearrange
-from flash_attn import flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

 import torch
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 import torch.nn as nn
 import torch.nn.functional as F

+try:
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+except:
+    pass
+
 try:
     from apex.normalization import FusedRMSNorm as RMSNorm
 except:
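With the imports wrapped in try/except, the module now loads even when flash_attn is not installed and the SDPA path is used instead. A minimal sketch of the same optional-dependency pattern with an explicit availability flag; HAS_FLASH_ATTN is illustrative and not part of this commit:

    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
        HAS_FLASH_ATTN = True
    except ImportError:
        HAS_FLASH_ATTN = False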
@@ -75,7 +80,15 @@ class LuminaParams:
     @classmethod
     def get_7b_config(cls) -> "LuminaParams":
         """Returns the configuration for the 7B parameter model"""
-        return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
+        return cls(
+            patch_size=2,
+            dim=4096,
+            n_layers=32,
+            n_heads=32,
+            n_kv_heads=8,
+            axes_dims=[64, 64, 64],
+            axes_lens=[300, 512, 512],
+        )


 class GradientCheckpointMixin(nn.Module):
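Usage sketch for the reformatted classmethod, assuming LuminaParams is imported from this module; the printed values come straight from the config above:

    params = LuminaParams.get_7b_config()
    print(params.dim, params.n_heads, params.n_kv_heads)  # 4096 32 8
    print(params.dim // params.n_heads)  # head_dim = 128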
@@ -248,6 +261,7 @@ class JointAttention(nn.Module):
         n_heads: int,
         n_kv_heads: Optional[int],
         qk_norm: bool,
+        use_flash_attn=False,
     ):
         """
         Initialize the Attention module.
@@ -286,7 +300,7 @@ class JointAttention(nn.Module):
         else:
             self.q_norm = self.k_norm = nn.Identity()

-        self.flash_attn = False
+        self.use_flash_attn = use_flash_attn

         # self.attention_processor = xformers.ops.memory_efficient_attention
         self.attention_processor = F.scaled_dot_product_attention
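The renamed use_flash_attn flag only selects the attention path; the SDPA fallback still goes through self.attention_processor, which defaults to F.scaled_dot_product_attention and can be swapped via set_attention_processor (next hunk) for any callable with the same calling convention. A hedged sketch with an illustrative wrapper; logging_sdpa is not part of the commit:

    import torch.nn.functional as F

    def logging_sdpa(q, k, v, attn_mask=None, scale=None):
        # same calling convention the forward pass uses: attn_mask= and scale= keywords
        print("attending over", q.shape[-2], "tokens")
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)

    # attention.set_attention_processor(logging_sdpa)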
@@ -294,35 +308,63 @@ class JointAttention(nn.Module):
     def set_attention_processor(self, attention_processor):
         self.attention_processor = attention_processor

-    @staticmethod
-    def apply_rotary_emb(
-        x_in: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply rotary embeddings to input tensors using the given frequency
-        tensor.
-
-        This function applies rotary embeddings to the given query 'xq' and
-        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-        input tensors are reshaped as complex numbers, and the frequency tensor
-        is reshaped for broadcasting compatibility. The resulting tensors
-        contain rotary embeddings and are returned as real tensors.
-
-        Args:
-            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                exponentials.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                and key tensor with rotary embeddings.
-        """
-        with torch.autocast("cuda", enabled=False):
-            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-            freqs_cis = freqs_cis.unsqueeze(2)
-            x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-            return x_out.type_as(x_in)
+    def forward(
+        self,
+        x: Tensor,
+        x_mask: Tensor,
+        freqs_cis: Tensor,
+    ) -> Tensor:
+        """
+            x:
+            x_mask:
+            freqs_cis:
+        """
+        bsz, seqlen, _ = x.shape
+        dtype = x.dtype
+
+        xq, xk, xv = torch.split(
+            self.qkv(x),
+            [
+                self.n_local_heads * self.head_dim,
+                self.n_local_kv_heads * self.head_dim,
+                self.n_local_kv_heads * self.head_dim,
+            ],
+            dim=-1,
+        )
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xq = self.q_norm(xq)
+        xk = self.k_norm(xk)
+        xq = apply_rope(xq, freqs_cis=freqs_cis)
+        xk = apply_rope(xk, freqs_cis=freqs_cis)
+        xq, xk = xq.to(dtype), xk.to(dtype)
+
+        softmax_scale = math.sqrt(1 / self.head_dim)
+
+        if self.use_flash_attn:
+            output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
+        else:
+            n_rep = self.n_local_heads // self.n_local_kv_heads
+            if n_rep >= 1:
+                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+            output = (
+                self.attention_processor(
+                    xq.permute(0, 2, 1, 3),
+                    xk.permute(0, 2, 1, 3),
+                    xv.permute(0, 2, 1, 3),
+                    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
+                    scale=softmax_scale,
+                )
+                .permute(0, 2, 1, 3)
+                .to(dtype)
+            )
+
+        output = output.flatten(-2)
+        return self.out(output)

     # copied from huggingface modeling_llama.py
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
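The removed apply_rotary_emb staticmethod and the apply_rope calls in the new forward implement the same rotary embedding: channel pairs are viewed as complex numbers and rotated by unit-magnitude frequencies. A small self-contained illustration of that rotation; shapes and random values are only illustrative:

    import torch

    x = torch.randn(1, 4, 2, 8)  # (batch, seq, heads, head_dim)
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = torch.polar(torch.ones(1, 4, 1, 4), torch.randn(1, 4, 1, 4))  # |freqs_cis| == 1
    rotated = torch.view_as_real(xc * freqs_cis).flatten(3).type_as(x)
    assert rotated.shape == x.shape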
@@ -377,46 +419,17 @@ class JointAttention(nn.Module):
             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
         )

-    def forward(
+    def flash_attn(
         self,
-        x: Tensor,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
         x_mask: Tensor,
-        freqs_cis: Tensor,
+        softmax_scale,
     ) -> Tensor:
-        """
-
-        Args:
-            x:
-            x_mask:
-            freqs_cis:
-
-        Returns:
-
-        """
-        bsz, seqlen, _ = x.shape
-        dtype = x.dtype
-
-        xq, xk, xv = torch.split(
-            self.qkv(x),
-            [
-                self.n_local_heads * self.head_dim,
-                self.n_local_kv_heads * self.head_dim,
-                self.n_local_kv_heads * self.head_dim,
-            ],
-            dim=-1,
-        )
-        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        xq = self.q_norm(xq)
-        xk = self.k_norm(xk)
-        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
-        xq, xk = xq.to(dtype), xk.to(dtype)
-
-        softmax_scale = math.sqrt(1 / self.head_dim)
-
-        if self.flash_attn:
-            try:
-                # begin var_len flash attn
-                (
-                    query_states,
+        bsz, seqlen, _, _ = q.shape
+
+        try:
+            # begin var_len flash attn
+            (
+                query_states,
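Both paths use the same softmax scale, computed just before the branch in the new forward and passed into flash_attn here. Worked value for the 7B config, where head_dim = 4096 // 32 = 128:

    import math

    head_dim = 4096 // 32
    softmax_scale = math.sqrt(1 / head_dim)
    print(round(softmax_scale, 4))  # 0.0884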
@@ -425,7 +438,7 @@ class JointAttention(nn.Module):
                 indices_q,
                 cu_seq_lens,
                 max_seq_lens,
-            ) = self._upad_input(xq, xk, xv, x_mask, seqlen)
+            ) = self._upad_input(q, k, v, x_mask, seqlen)

             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
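_upad_input (the helper copied from Hugging Face's modeling_llama.py) strips padding and returns cumulative sequence lengths for flash_attn_varlen_func. A minimal sketch of that bookkeeping derived from a padding mask; tensor values are illustrative:

    import torch
    import torch.nn.functional as F

    x_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two sequences padded to length 4
    seqlens = x_mask.sum(dim=-1, dtype=torch.int32)  # tensor([3, 2])
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])
    max_seqlen = int(seqlens.max())  # 3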
@@ -445,27 +458,12 @@ class JointAttention(nn.Module):
             output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
             # end var_len_flash_attn
+
+            return output
         except NameError as e:
             raise RuntimeError(
                 f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
             )
-        else:
-            n_rep = self.n_local_heads // self.n_local_kv_heads
-            if n_rep >= 1:
-                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-
-            output = (
-                self.attention_processor(
-                    xq.permute(0, 2, 1, 3),
-                    xk.permute(0, 2, 1, 3),
-                    xv.permute(0, 2, 1, 3),
-                    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
-                    scale=softmax_scale,
-                )
-                .permute(0, 2, 1, 3)
-                .to(dtype)
-            )
-
-        output = output.flatten(-2)
-        return self.out(output)


 def apply_rope(
     x_in: torch.Tensor,
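The SDPA branch (removed here and reinstated in the new forward) repeats key/value heads so grouped-query attention matches the query head count; with the 7B config each of the 8 KV heads is repeated 4 times to cover 32 query heads. A shape-level sketch of that repeat:

    import torch

    n_local_heads, n_local_kv_heads, head_dim = 32, 8, 128
    n_rep = n_local_heads // n_local_kv_heads  # 4
    xk = torch.randn(1, 10, n_local_kv_heads, head_dim)
    xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
    assert xk.shape == (1, 10, n_local_heads, head_dim)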
@@ -563,6 +561,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        use_flash_attn=False,
     ) -> None:
         """
         Initialize a TransformerBlock.
@@ -585,7 +584,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
+        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
        self.feed_forward = FeedForward(
             dim=dim,
             hidden_dim=4 * dim,
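The block now simply forwards its use_flash_attn flag into JointAttention. A hedged construction sketch using the 7B sizes; the qk_norm value is illustrative:

    attn = JointAttention(4096, 32, 8, True, use_flash_attn=False)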
@@ -711,7 +710,12 @@ class FinalLayer(GradientCheckpointMixin):


 class RopeEmbedder:
-    def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
+    def __init__(
+        self,
+        theta: float = 10000.0,
+        axes_dims: List[int] = [16, 56, 56],
+        axes_lens: List[int] = [1, 512, 512],
+    ):
         super().__init__()
         self.theta = theta
         self.axes_dims = axes_dims
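RopeEmbedder keeps its defaults under the reformatted signature; the 7B config overrides them with axes_dims=[64, 64, 64] and axes_lens=[300, 512, 512], i.e. three position axes (presumably caption index plus image height and width, matching the caption/image split later in the file). Construction sketch:

    rope = RopeEmbedder(theta=10000.0, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])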
@@ -750,6 +754,7 @@ class NextDiT(nn.Module):
         cap_feat_dim: int = 5120,
         axes_dims: List[int] = [16, 56, 56],
         axes_lens: List[int] = [1, 512, 512],
+        use_flash_attn=False,
     ) -> None:
         """
         Initialize the NextDiT model.
@@ -803,6 +808,7 @@ class NextDiT(nn.Module):
                     norm_eps,
                     qk_norm,
                     modulation=False,
+                    use_flash_attn=use_flash_attn,
                 )
                 for layer_id in range(n_refiner_layers)
             ]
@@ -828,6 +834,7 @@ class NextDiT(nn.Module):
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    use_flash_attn=use_flash_attn,
                 )
                 for layer_id in range(n_refiner_layers)
             ]
@@ -848,6 +855,7 @@ class NextDiT(nn.Module):
                     ffn_dim_multiplier,
                     norm_eps,
                     qk_norm,
+                    use_flash_attn=use_flash_attn,
                 )
                 for layer_id in range(n_layers)
             ]
@@ -988,8 +996,20 @@ class NextDiT(nn.Module):
         freqs_cis = self.rope_embedder(position_ids)

         # Create separate rotary embeddings for captions and images
-        cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
-        img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
+        cap_freqs_cis = torch.zeros(
+            bsz,
+            encoder_seq_len,
+            freqs_cis.shape[-1],
+            device=device,
+            dtype=freqs_cis.dtype,
+        )
+        img_freqs_cis = torch.zeros(
+            bsz,
+            image_seq_len,
+            freqs_cis.shape[-1],
+            device=device,
+            dtype=freqs_cis.dtype,
+        )

         for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
             cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
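The per-sample loop below the reformatted torch.zeros calls routes the first cap_len rotary rows of each sample to the caption stream; presumably the image rows follow the same pattern in lines not shown here. A small stand-alone illustration of that slicing; shapes, dtypes and lengths are illustrative:

    import torch

    bsz, encoder_seq_len, head_dim = 2, 6, 8
    freqs_cis = torch.randn(bsz, 16, head_dim)
    l_effective_cap_len = [4, 6]
    cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, head_dim)
    for i, cap_len in enumerate(l_effective_cap_len):
        cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]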