Add documentation to model, use SDPA attention, sample images

This commit is contained in:
rockerBOO
2025-02-18 00:58:53 -05:00
parent 1aa2f00e85
commit 98efbc3bb7
5 changed files with 643 additions and 333 deletions

View File

@@ -13,6 +13,7 @@ import math
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from einops import rearrange
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
import torch import torch
@@ -23,24 +24,16 @@ import torch.nn.functional as F
try: try:
from apex.normalization import FusedRMSNorm as RMSNorm from apex.normalization import FusedRMSNorm as RMSNorm
except ModuleNotFoundError: except:
import warnings import warnings
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
memory_efficient_attention = None
try:
import xformers
except:
pass
try:
from xformers.ops import memory_efficient_attention
except:
memory_efficient_attention = None
@dataclass @dataclass
class LuminaParams: class LuminaParams:
"""Parameters for Lumina model configuration""" """Parameters for Lumina model configuration"""
patch_size: int = 2 patch_size: int = 2
in_channels: int = 4 in_channels: int = 4
dim: int = 4096 dim: int = 4096
@@ -68,7 +61,7 @@ class LuminaParams:
"""Returns the configuration for the 2B parameter model""" """Returns the configuration for the 2B parameter model"""
return cls( return cls(
patch_size=2, patch_size=2,
in_channels=16, in_channels=16, # VAE channels
dim=2304, dim=2304,
n_layers=26, n_layers=26,
n_heads=24, n_heads=24,
@@ -76,21 +69,13 @@ class LuminaParams:
axes_dims=[32, 32, 32], axes_dims=[32, 32, 32],
axes_lens=[300, 512, 512], axes_lens=[300, 512, 512],
qk_norm=True, qk_norm=True,
cap_feat_dim=2304 cap_feat_dim=2304, # Gemma 2 hidden_size
) )
@classmethod @classmethod
def get_7b_config(cls) -> "LuminaParams": def get_7b_config(cls) -> "LuminaParams":
"""Returns the configuration for the 7B parameter model""" """Returns the configuration for the 7B parameter model"""
return cls( return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
patch_size=2,
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
axes_dims=[64, 64, 64],
axes_lens=[300, 512, 512]
)
class GradientCheckpointMixin(nn.Module): class GradientCheckpointMixin(nn.Module):
@@ -112,6 +97,7 @@ class GradientCheckpointMixin(nn.Module):
else: else:
return self._forward(*args, **kwargs) return self._forward(*args, **kwargs)
############################################################################# #############################################################################
# RMSNorm # # RMSNorm #
############################################################################# #############################################################################
@@ -148,9 +134,18 @@ class RMSNorm(torch.nn.Module):
""" """
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: Tensor): def forward(self, x: Tensor):
"""
Apply RMSNorm to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
x_dtype = x.dtype x_dtype = x.dtype
# To handle float8 we need to convert the tensor to float
x = x.float() x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
@@ -204,17 +199,11 @@ class TimestepEmbedder(GradientCheckpointMixin):
""" """
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None] args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
embedding = torch.cat( embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding return embedding
def _forward(self, t): def _forward(self, t):
@@ -222,6 +211,7 @@ class TimestepEmbedder(GradientCheckpointMixin):
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb return t_emb
def to_cuda(x): def to_cuda(x):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return x.cuda() return x.cuda()
@@ -266,6 +256,7 @@ class JointAttention(nn.Module):
dim (int): Number of input dimensions. dim (int): Number of input dimensions.
n_heads (int): Number of heads. n_heads (int): Number of heads.
n_kv_heads (Optional[int]): Number of kv heads, if using GQA. n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
qk_norm (bool): Whether to use normalization for queries and keys.
""" """
super().__init__() super().__init__()
@@ -295,6 +286,14 @@ class JointAttention(nn.Module):
else: else:
self.q_norm = self.k_norm = nn.Identity() self.q_norm = self.k_norm = nn.Identity()
self.flash_attn = False
# self.attention_processor = xformers.ops.memory_efficient_attention
self.attention_processor = F.scaled_dot_product_attention
def set_attention_processor(self, attention_processor):
self.attention_processor = attention_processor
@staticmethod @staticmethod
def apply_rotary_emb( def apply_rotary_emb(
x_in: torch.Tensor, x_in: torch.Tensor,
@@ -326,16 +325,12 @@ class JointAttention(nn.Module):
return x_out.type_as(x_in) return x_out.type_as(x_in)
# copied from huggingface modeling_llama.py # copied from huggingface modeling_llama.py
def _upad_input( def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
self, query_layer, key_layer, value_layer, attention_mask, query_length
):
def _get_unpad_data(attention_mask): def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad( cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
return ( return (
indices, indices,
cu_seqlens, cu_seqlens,
@@ -355,9 +350,7 @@ class JointAttention(nn.Module):
) )
if query_length == kv_seq_len: if query_length == kv_seq_len:
query_layer = index_first_axis( query_layer = index_first_axis(
query_layer.reshape( query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
batch_size * kv_seq_len, self.n_local_heads, head_dim
),
indices_k, indices_k,
) )
cu_seqlens_q = cu_seqlens_k cu_seqlens_q = cu_seqlens_k
@@ -373,9 +366,7 @@ class JointAttention(nn.Module):
else: else:
# The -q_len: slice assumes left padding. # The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:] attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
query_layer, attention_mask
)
return ( return (
query_layer, query_layer,
@@ -388,10 +379,10 @@ class JointAttention(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: Tensor,
x_mask: torch.Tensor, x_mask: Tensor,
freqs_cis: torch.Tensor, freqs_cis: Tensor,
) -> torch.Tensor: ) -> Tensor:
""" """
Args: Args:
@@ -425,7 +416,7 @@ class JointAttention(nn.Module):
softmax_scale = math.sqrt(1 / self.head_dim) softmax_scale = math.sqrt(1 / self.head_dim)
if dtype in [torch.float16, torch.bfloat16]: if self.flash_attn:
# begin var_len flash attn # begin var_len flash attn
( (
query_states, query_states,
@@ -459,14 +450,13 @@ class JointAttention(nn.Module):
if n_rep >= 1: if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = ( output = (
F.scaled_dot_product_attention( self.attention_processor(
xq.permute(0, 2, 1, 3), xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3), xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3), xv.permute(0, 2, 1, 3),
attn_mask=x_mask.bool() attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
.view(bsz, 1, 1, seqlen)
.expand(-1, self.n_local_heads, seqlen, -1),
scale=softmax_scale, scale=softmax_scale,
) )
.permute(0, 2, 1, 3) .permute(0, 2, 1, 3)
@@ -474,10 +464,47 @@ class JointAttention(nn.Module):
) )
output = output.flatten(-2) output = output.flatten(-2)
return self.out(output) return self.out(output)
def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def apply_rope(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given tensor 'x_in' (a
query or key tensor) using the provided frequency tensor 'freqs_cis'. The
input tensor is reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensor
contains rotary embeddings and is returned as a real tensor.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
torch.Tensor: The input tensor with rotary embeddings applied,
cast back to the type of x_in.
"""
with torch.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in)
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__( def __init__(
self, self,
@@ -554,10 +581,13 @@ class JointTransformerBlock(GradientCheckpointMixin):
n_kv_heads (Optional[int]): Number of attention heads in key and n_kv_heads (Optional[int]): Number of attention heads in key and
value features (if using GQA), or set to None for the same as value features (if using GQA), or set to None for the same as
query. query.
multiple_of (int): multiple_of (int): Number of multiple of the hidden dimension.
ffn_dim_multiplier (Optional[float]): ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
norm_eps (float): feedforward layer.
norm_eps (float): Epsilon value for normalization.
qk_norm (bool): Whether to use normalization for queries and keys.
modulation (bool): Whether to use modulation for the attention
layer.
""" """
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@@ -593,32 +623,30 @@ class JointTransformerBlock(GradientCheckpointMixin):
self, self,
x: torch.Tensor, x: torch.Tensor,
x_mask: torch.Tensor, x_mask: torch.Tensor,
freqs_cis: torch.Tensor, pe: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None, adaln_input: Optional[torch.Tensor] = None,
): ):
""" """
Perform a forward pass through the TransformerBlock. Perform a forward pass through the TransformerBlock.
Args: Args:
x (torch.Tensor): Input tensor. x (Tensor): Input tensor.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. pe (Tensor): Rope position embedding.
Returns: Returns:
torch.Tensor: Output tensor after applying attention and Tensor: Output tensor after applying attention and
feedforward layers. feedforward layers.
""" """
if self.modulation: if self.modulation:
assert adaln_input is not None assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation( scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
adaln_input
).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention( self.attention(
modulate(self.attention_norm1(x), scale_msa), modulate(self.attention_norm1(x), scale_msa),
x_mask, x_mask,
freqs_cis, pe,
) )
) )
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -632,7 +660,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
self.attention( self.attention(
self.attention_norm1(x), self.attention_norm1(x),
x_mask, x_mask,
freqs_cis, pe,
) )
) )
x = x + self.ffn_norm2( x = x + self.ffn_norm2(
@@ -649,6 +677,14 @@ class FinalLayer(GradientCheckpointMixin):
""" """
def __init__(self, hidden_size, patch_size, out_channels): def __init__(self, hidden_size, patch_size, out_channels):
"""
Initialize the FinalLayer.
Args:
hidden_size (int): Hidden size of the input features.
patch_size (int): Patch size of the input features.
out_channels (int): Number of output channels.
"""
super().__init__() super().__init__()
self.norm_final = nn.LayerNorm( self.norm_final = nn.LayerNorm(
hidden_size, hidden_size,
@@ -682,39 +718,21 @@ class FinalLayer(GradientCheckpointMixin):
class RopeEmbedder: class RopeEmbedder:
def __init__( def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
self,
theta: float = 10000.0,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
):
super().__init__() super().__init__()
self.theta = theta self.theta = theta
self.axes_dims = axes_dims self.axes_dims = axes_dims
self.axes_lens = axes_lens self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis( self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.axes_dims, self.axes_lens, theta=self.theta
)
def get_freqs_cis(self, ids: torch.Tensor): def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
result = [] result = []
for i in range(len(self.axes_dims)): for i in range(len(self.axes_dims)):
index = ( freqs = self.freqs_cis[i].to(ids.device)
ids[:, :, i : i + 1] index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
.repeat(1, 1, self.freqs_cis[i].shape[-1]) result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
.to(torch.int64)
)
axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1)
result.append(
torch.gather(
axes,
dim=1,
index=index,
)
)
return torch.cat(result, dim=-1) return torch.cat(result, dim=-1)
@@ -740,11 +758,63 @@ class NextDiT(nn.Module):
axes_dims: List[int] = [16, 56, 56], axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512], axes_lens: List[int] = [1, 512, 512],
) -> None: ) -> None:
"""
Initialize the NextDiT model.
Args:
patch_size (int): Patch size of the input features.
in_channels (int): Number of input channels.
dim (int): Hidden size of the input features.
n_layers (int): Number of Transformer layers.
n_refiner_layers (int): Number of refiner layers.
n_heads (int): Number of attention heads.
n_kv_heads (Optional[int]): Number of attention heads in key and
value features (if using GQA), or set to None for the same as
query.
multiple_of (int): Multiple of the hidden size.
ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
feedforward layer.
norm_eps (float): Epsilon value for normalization.
qk_norm (bool): Whether to use query key normalization.
cap_feat_dim (int): Dimension of the caption features.
axes_dims (List[int]): List of dimensions for the axes.
axes_lens (List[int]): List of lengths for the axes.
Returns:
None
"""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = in_channels self.out_channels = in_channels
self.patch_size = patch_size self.patch_size = patch_size
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
dim,
bias=True,
),
)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.x_embedder = nn.Linear( self.x_embedder = nn.Linear(
in_features=patch_size * patch_size * in_channels, in_features=patch_size * patch_size * in_channels,
out_features=dim, out_features=dim,
@@ -769,32 +839,7 @@ class NextDiT(nn.Module):
for layer_id in range(n_refiner_layers) for layer_id in range(n_refiner_layers)
] ]
) )
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024))
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
dim,
bias=True,
),
)
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
# nn.init.zeros_(self.cap_embedder[1].weight) # nn.init.zeros_(self.cap_embedder[1].weight)
nn.init.zeros_(self.cap_embedder[1].bias) nn.init.zeros_(self.cap_embedder[1].bias)
@@ -864,15 +909,26 @@ class NextDiT(nn.Module):
def unpatchify( def unpatchify(
self, self,
x: torch.Tensor, x: Tensor,
width: int, width: int,
height: int, height: int,
encoder_seq_lengths: List[int], encoder_seq_lengths: List[int],
seq_lengths: List[int], seq_lengths: List[int],
) -> torch.Tensor: ) -> Tensor:
""" """
Unpatchify the input tensor and embed the caption features.
x: (N, T, patch_size**2 * C) x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C) imgs: (N, H, W, C)
Args:
x (Tensor): Input tensor.
width (int): Width of the input tensor.
height (int): Height of the input tensor.
encoder_seq_lengths (List[int]): List of encoder sequence lengths.
seq_lengths (List[int]): List of sequence lengths
Returns:
output: (N, C, H, W)
""" """
pH = pW = self.patch_size pH = pW = self.patch_size
@@ -891,13 +947,27 @@ class NextDiT(nn.Module):
def patchify_and_embed( def patchify_and_embed(
self, self,
x: torch.Tensor, x: Tensor,
cap_feats: torch.Tensor, cap_feats: Tensor,
cap_mask: torch.Tensor, cap_mask: Tensor,
t: torch.Tensor, t: Tensor,
) -> Tuple[ ) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int] """
]: Patchify and embed the input image and caption features.
Args:
x: (N, C, H, W) image latents
cap_feats: (N, L, D) caption features
cap_mask: (N, L) caption attention mask
t: (N), T timesteps
Returns:
Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
"""
bsz, channels, height, width = x.shape bsz, channels, height, width = x.shape
pH = pW = self.patch_size pH = pW = self.patch_size
device = x.device device = x.device
@@ -915,40 +985,35 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = height // pH, width // pW H_tokens, W_tokens = height // pH, width // pW
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len position_ids[i, cap_len:seq_len, 0] = cap_len
row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device)
.view(-1, 1)
.repeat(1, W_tokens)
.flatten()
)
col_ids = (
torch.arange(W_tokens, dtype=torch.int32, device=device)
.view(1, -1)
.repeat(H_tokens, 1)
.flatten()
)
position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids
position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids
freqs_cis = self.rope_embedder.get_freqs_cis(position_ids) row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:seq_len, 1] = row_ids
position_ids[i, cap_len:seq_len, 2] = col_ids
# Get combined rotary embeddings
freqs_cis = self.rope_embedder(position_ids)
# Create separate rotary embeddings for captions and images
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len] img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len]
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) # Refine caption context
x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
# refine context
for layer in self.context_refiner: for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
x = self.x_embedder(x) x = self.x_embedder(x)
# Refine image context
for layer in self.noise_refiner: for layer in self.noise_refiner:
x = layer(x, x_mask, img_freqs_cis, t) x = layer(x, x_mask, img_freqs_cis, t)
@@ -963,19 +1028,23 @@ class NextDiT(nn.Module):
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
def forward(self, x, t, cap_feats, cap_mask): def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor:
""" """
Forward pass of NextDiT. Forward pass of NextDiT.
t: (N,) tensor of diffusion timesteps Args:
y: (N,) tensor of text tokens/features x: (N, C, H, W) image latents
t: (N,) tensor of diffusion timesteps
cap_feats: (N, L, D) caption features
cap_mask: (N, L) caption attention mask
Returns:
x: (N, C, H, W) denoised latents
""" """
_, _, height, width = x.shape # B, C, H, W _, _, height, width = x.shape # B, C, H, W
t = self.t_embedder(t) # (N, D) t = self.t_embedder(t) # (N, D)
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed( x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)
x, cap_feats, cap_mask, t
)
for layer in self.layers: for layer in self.layers:
x = layer(x, mask, freqs_cis, t) x = layer(x, mask, freqs_cis, t)
@@ -986,7 +1055,14 @@ class NextDiT(nn.Module):
return x return x
def forward_with_cfg( def forward_with_cfg(
self, x, t, cap_feats, cap_mask, cfg_scale, cfg_trunc=100, renorm_cfg=1 self,
x: Tensor,
t: Tensor,
cap_feats: Tensor,
cap_mask: Tensor,
cfg_scale: float,
cfg_trunc: int = 100,
renorm_cfg: float = 1.0,
): ):
""" """
Forward pass of NextDiT, but also batches the unconditional forward pass Forward pass of NextDiT, but also batches the unconditional forward pass
@@ -996,9 +1072,10 @@ class NextDiT(nn.Module):
half = x[: len(x) // 2] half = x[: len(x) // 2]
if t[0] < cfg_trunc: if t[0] < cfg_trunc:
combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128] combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128]
model_out = self.forward( assert (
combined, t, cap_feats, cap_mask cap_mask.shape[0] == combined.shape[0]
) # [2, 16, 128, 128] ), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}"
model_out = self.forward(x, t, cap_feats, cap_mask) # [2, 16, 128, 128]
# For exact reproducibility reasons, we apply classifier-free guidance on only # For exact reproducibility reasons, we apply classifier-free guidance on only
# three channels by default. The standard approach to cfg applies it to all channels. # three channels by default. The standard approach to cfg applies it to all channels.
# This can be done by uncommenting the following line and commenting-out the line following that. # This can be done by uncommenting the following line and commenting-out the line following that.
@@ -1009,13 +1086,9 @@ class NextDiT(nn.Module):
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
if float(renorm_cfg) > 0.0: if float(renorm_cfg) > 0.0:
ori_pos_norm = torch.linalg.vector_norm( ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True)
cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
)
max_new_norm = ori_pos_norm * float(renorm_cfg) max_new_norm = ori_pos_norm * float(renorm_cfg)
new_pos_norm = torch.linalg.vector_norm( new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True)
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
if new_pos_norm >= max_new_norm: if new_pos_norm >= max_new_norm:
half_eps = half_eps * (max_new_norm / new_pos_norm) half_eps = half_eps * (max_new_norm / new_pos_norm)
else: else:
@@ -1040,7 +1113,7 @@ class NextDiT(nn.Module):
dim: List[int], dim: List[int],
end: List[int], end: List[int],
theta: float = 10000.0, theta: float = 10000.0,
): ) -> List[Tensor]:
""" """
Precompute the frequency tensor for complex exponentials (cis) with Precompute the frequency tensor for complex exponentials (cis) with
given dimensions. given dimensions.
@@ -1057,19 +1130,17 @@ class NextDiT(nn.Module):
Defaults to 10000.0. Defaults to 10000.0.
Returns: Returns:
torch.Tensor: Precomputed frequency tensor with complex List[torch.Tensor]: Precomputed frequency tensor with complex
exponentials. exponentials.
""" """
freqs_cis = [] freqs_cis = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(dim, end)): for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / ( pos = torch.arange(e, dtype=freqs_dtype, device="cpu")
theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d) freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d))
) freqs = torch.outer(pos, freqs)
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2]
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(
torch.complex64
) # complex64
freqs_cis.append(freqs_cis_i) freqs_cis.append(freqs_cis_i)
return freqs_cis return freqs_cis
@@ -1102,7 +1173,7 @@ class NextDiT(nn.Module):
def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs): def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs):
if params is None: if params is None:
params = LuminaParams.get_2b_config() params = LuminaParams.get_2b_config()
return NextDiT( return NextDiT(
patch_size=params.patch_size, patch_size=params.patch_size,
in_channels=params.in_channels, in_channels=params.in_channels,

View File

@@ -2,20 +2,20 @@ import argparse
import math import math
import os import os
import numpy as np import numpy as np
import toml
import json
import time import time
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Any
import torch import torch
from torch import Tensor
from accelerate import Accelerator, PartialState from accelerate import Accelerator, PartialState
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import Gemma2Model
from tqdm import tqdm from tqdm import tqdm
from PIL import Image from PIL import Image
from safetensors.torch import save_file from safetensors.torch import save_file
from library import lumina_models, lumina_util, strategy_base, train_util from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex() init_ipex()
@@ -30,19 +30,38 @@ logger = logging.getLogger(__name__)
# region sample images # region sample images
@torch.no_grad()
def sample_images( def sample_images(
accelerator: Accelerator, accelerator: Accelerator,
args: argparse.Namespace, args: argparse.Namespace,
epoch, epoch: int,
steps, global_step: int,
nextdit, nextdit: lumina_models.NextDiT,
ae, vae: torch.nn.Module,
gemma2_model, gemma2_model: Gemma2Model,
sample_prompts_gemma2_outputs, sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
prompt_replacement=None, prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None controlnet=None,
): ):
if steps == 0: """
Generate sample images using the NextDiT model.
Args:
accelerator (Accelerator): Accelerator instance.
args (argparse.Namespace): Command-line arguments.
epoch (int): Current epoch number.
global_step (int): Current global step number.
nextdit (lumina_models.NextDiT): The NextDiT model instance.
vae (torch.nn.Module): The VAE module.
gemma2_model (Gemma2Model): The Gemma2 model instance.
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample.
prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None.
controlnet: ControlNet model
Returns:
None
"""
if global_step == 0:
if not args.sample_at_first: if not args.sample_at_first:
return return
else: else:
@@ -53,11 +72,15 @@ def sample_images(
if epoch is None or epoch % args.sample_every_n_epochs != 0: if epoch is None or epoch % args.sample_every_n_epochs != 0:
return return
else: else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return return
assert (
args.sample_prompts is not None
), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"
logger.info("") logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return return
@@ -87,22 +110,21 @@ def sample_images(
if distributed_state.num_processes <= 1: if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts:
for prompt_dict in prompts: sample_image_inference(
sample_image_inference( accelerator,
accelerator, args,
args, nextdit,
nextdit, gemma2_model,
gemma2_model, vae,
ae, save_dir,
save_dir, prompt_dict,
prompt_dict, epoch,
epoch, global_step,
steps, sample_prompts_gemma2_outputs,
sample_prompts_gemma2_outputs, prompt_replacement,
prompt_replacement, controlnet,
controlnet )
)
else: else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
@@ -110,23 +132,22 @@ def sample_images(
for i in range(distributed_state.num_processes): for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes]) per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad(): with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists[0]:
for prompt_dict in prompt_dict_lists[0]: sample_image_inference(
sample_image_inference( accelerator,
accelerator, args,
args, nextdit,
nextdit, gemma2_model,
gemma2_model, vae,
ae, save_dir,
save_dir, prompt_dict,
prompt_dict, epoch,
epoch, global_step,
steps, sample_prompts_gemma2_outputs,
sample_prompts_gemma2_outputs, prompt_replacement,
prompt_replacement, controlnet,
controlnet )
)
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
if cuda_rng_state is not None: if cuda_rng_state is not None:
@@ -135,43 +156,60 @@ def sample_images(
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
@torch.no_grad()
def sample_image_inference( def sample_image_inference(
accelerator: Accelerator, accelerator: Accelerator,
args: argparse.Namespace, args: argparse.Namespace,
nextdit, nextdit: lumina_models.NextDiT,
gemma2_model, gemma2_model: Gemma2Model,
ae, vae: torch.nn.Module,
save_dir, save_dir: str,
prompt_dict, prompt_dict: Dict[str, str],
epoch, epoch: int,
steps, global_step: int,
sample_prompts_gemma2_outputs, sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
prompt_replacement, prompt_replacement: Optional[Tuple[str, str]] = None,
# controlnet controlnet=None,
): ):
"""
Generates sample images
Args:
accelerator (Accelerator): Accelerator object
args (argparse.Namespace): Arguments object
nextdit (lumina_models.NextDiT): NextDiT model
gemma2_model (Gemma2Model): Gemma2 model
vae (torch.nn.Module): VAE model
save_dir (str): Directory to save images
prompt_dict (Dict[str, str]): Prompt dictionary
epoch (int): Epoch number
global_step (int): Current global step number
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs
prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
Returns:
None
"""
assert isinstance(prompt_dict, dict) assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt") # negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20) sample_steps = prompt_dict.get("sample_steps", 38)
width = prompt_dict.get("width", 512) width = prompt_dict.get("width", 1024)
height = prompt_dict.get("height", 512) height = prompt_dict.get("height", 1024)
scale = prompt_dict.get("scale", 3.5) guidance_scale: int = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed") seed: int = prompt_dict.get("seed", None)
controlnet_image = prompt_dict.get("controlnet_image") controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "") prompt: str = prompt_dict.get("prompt", "")
negative_prompt: str = prompt_dict.get("negative_prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None: if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None: if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
generator = torch.Generator(device=accelerator.device)
if seed is not None: if seed is not None:
torch.manual_seed(seed) generator.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
# if negative_prompt is None: # if negative_prompt is None:
# negative_prompt = "" # negative_prompt = ""
@@ -182,7 +220,7 @@ def sample_image_inference(
logger.info(f"height: {height}") logger.info(f"height: {height}")
logger.info(f"width: {width}") logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}") logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}") logger.info(f"scale: {guidance_scale}")
# logger.info(f"sample_sampler: {sampler_name}") # logger.info(f"sample_sampler: {sampler_name}")
if seed is not None: if seed is not None:
logger.info(f"seed: {seed}") logger.info(f"seed: {seed}")
@@ -191,14 +229,16 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
gemma2_conds = [] gemma2_conds = []
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt] gemma2_conds = sample_prompts_gemma2_outputs[prompt]
print(f"Using cached Gemma2 outputs for prompt: {prompt}") logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
if gemma2_model is not None: if gemma2_model is not None:
print(f"Encoding prompt with Gemma2: {prompt}") logger.info(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt) tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_gemma2_attn_mask option
encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
# if gemma2_conds is not cached, use encoded_gemma2_conds # if gemma2_conds is not cached, use encoded_gemma2_conds
@@ -211,22 +251,26 @@ def sample_image_inference(
gemma2_conds[i] = encoded_gemma2_conds[i] gemma2_conds[i] = encoded_gemma2_conds[i]
# Unpack Gemma2 outputs # Unpack Gemma2 outputs
gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
# sample image # sample image
weight_dtype = ae.dtype # TOFO give dtype as argument weight_dtype = vae.dtype # TOFO give dtype as argument
packed_latent_height = height // 16 latent_height = height // 8
packed_latent_width = width // 16 latent_width = width // 8
noise = torch.randn( noise = torch.randn(
1, 1,
packed_latent_height * packed_latent_width, 16,
16 * 2 * 2, latent_height,
latent_width,
device=accelerator.device, device=accelerator.device,
dtype=weight_dtype, dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, generator=generator,
) )
# Prompts are paired positive/negative
noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) # img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)
# if controlnet_image is not None: # if controlnet_image is not None:
@@ -235,18 +279,18 @@ def sample_image_inference(
# controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
# controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad(): with accelerator.autocast():
x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale)
x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) # x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image # latent to image
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
org_vae_device = ae.device # will be on cpu org_vae_device = vae.device # will be on cpu
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast(), torch.no_grad(): with accelerator.autocast():
x = ae.decode(x) x = vae.decode(x)
ae.to(org_vae_device) vae.to(org_vae_device)
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1) x = x.clamp(-1, 1)
@@ -257,9 +301,9 @@ def sample_image_inference(
# but adding 'enum' to the filename should be enough # but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
seed_suffix = "" if seed is None else f"_{seed}" seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"] i: int = int(prompt_dict.get("enum", 0))
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename)) image.save(os.path.join(save_dir, img_filename))
@@ -273,11 +317,34 @@ def sample_image_inference(
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: Tensor) -> Tensor:
    """
    Apply the exponential time shift to a set of timesteps.

    Evaluates ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` element-wise.

    Args:
        mu (float): Shift parameter; it enters the formula as ``exp(mu)``.
        sigma (float): Exponent applied to ``(1 / t - 1)``.
        t (Tensor): Timesteps to shift.

    Returns:
        Tensor: The shifted timesteps, element-wise over ``t``.
    """
    # exp(mu) appears in both numerator and denominator; evaluate it once.
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """
    Build the linear function that passes through ``(x1, y1)`` and ``(x2, y2)``.

    Args:
        x1 (float, optional): x coordinate of the first point. Defaults to 256.
        y1 (float, optional): y coordinate of the first point. Defaults to 0.5.
        x2 (float, optional): x coordinate of the second point. Defaults to 4096.
        y2 (float, optional): y coordinate of the second point. Defaults to 1.15.

    Returns:
        Callable[[float], float]: ``f(x) = m * x + b`` with the slope ``m`` and
        intercept ``b`` determined by the two points.

    Raises:
        ZeroDivisionError: If ``x1 == x2`` (for plain Python numbers), since the
            slope is undefined.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return lambda x: slope * x + intercept
@@ -290,6 +357,19 @@ def get_schedule(
max_shift: float = 1.15, max_shift: float = 1.15,
shift: bool = True, shift: bool = True,
) -> list[float]: ) -> list[float]:
"""
Get timesteps schedule
Args:
num_steps (int): Number of steps in the schedule.
image_seq_len (int): Sequence length of the image.
base_shift (float, optional): Base shift value. Defaults to 0.5.
max_shift (float, optional): Maximum shift value. Defaults to 1.15.
shift (bool, optional): Whether to shift the schedule. Defaults to True.
Return:
List[float]: timesteps schedule
"""
# extra step for zero # extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1) timesteps = torch.linspace(1, 0, num_steps + 1)
@@ -301,11 +381,63 @@ def get_schedule(
return timesteps.tolist() return timesteps.tolist()
def denoise(
    model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0
):
    """
    Iteratively denoise latents with the NextDiT model.

    Walks consecutive pairs of ``timesteps`` and, for each pair, advances the
    latents by ``(t_next - t_now) * prediction``, where the prediction comes
    from ``model.forward_with_cfg``.

    Args:
        model (lumina_models.NextDiT): The NextDiT model instance.
        img (Tensor): The latents to denoise.
        txt (Tensor): Text features passed to the model as ``cap_feats``.
        txt_mask (Tensor): Text mask; cast to int32 and passed as ``cap_mask``.
        timesteps (List[float]): Timestep schedule; each adjacent pair forms one step.
        guidance (float, optional): Value passed as ``cfg_scale``. Defaults to 4.0.

    Returns:
        img (Tensor): Denoised tensor
    """
    step_pairs = zip(tqdm(timesteps[:-1]), timesteps[1:])
    for t_now, t_next in step_pairs:
        # One timestep value per batch element, matching the latents' dtype/device.
        batch_size = img.shape[0]
        t_batch = torch.full((batch_size,), t_now, dtype=img.dtype, device=img.device)

        model_out = model.forward_with_cfg(
            x=img,  # image latents (B, C, H, W)
            # Timesteps are divided by 1000 to match what the model expects.
            # NOTE(review): the visible part of get_schedule starts from
            # torch.linspace(1, 0, ...), i.e. values already in [0, 1] - confirm
            # this extra scaling is intended for sampling.
            t=t_batch / 1000,
            cap_feats=txt,  # Gemma2 hidden states used as caption features
            cap_mask=txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
            cfg_scale=guidance,
        )

        step_size = t_next - t_now
        img = img + step_size * model_out

    return img
# endregion # endregion
# region train # region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): def get_sigmas(
noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32
) -> Tensor:
"""
Get sigmas for timesteps
Args:
noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance.
timesteps (Tensor): A tensor of timesteps for the denoising process.
device (torch.device): The device on which the tensors are stored.
n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4.
dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32.
Returns:
sigmas (Tensor): The sigmas tensor.
"""
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device) schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device) timesteps = timesteps.to(device)
@@ -320,11 +452,22 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
def compute_density_for_timestep_sampling( def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
): ):
"""Compute the density for sampling the timesteps when doing SD3 training. """
Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1. SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
Args:
weighting_scheme (str): The weighting scheme to use.
batch_size (int): The batch size for the sampling process.
logit_mean (float, optional): The mean of the logit distribution. Defaults to None.
logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None.
mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None.
Returns:
u (Tensor): The sampled timesteps.
""" """
if weighting_scheme == "logit_normal": if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
@@ -338,12 +481,19 @@ def compute_density_for_timestep_sampling(
return u return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor:
"""Computes loss weighting scheme for SD3 training. """Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1. SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
Args:
weighting_scheme (str): The weighting scheme to use.
sigmas (Tensor, optional): The sigmas tensor. Defaults to None.
Returns:
weighting (Tensor): The computed loss weighting.
""" """
if weighting_scheme == "sigma_sqrt": if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float() weighting = (sigmas**-2.0).float()
@@ -355,9 +505,24 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting return weighting
def get_noisy_model_input_and_timesteps( def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]:
args, noise_scheduler, latents, noise, device, dtype """
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
sigmas
"""
bsz, _, h, w = latents.shape bsz, _, h, w = latents.shape
sigmas = None sigmas = None
@@ -412,7 +577,21 @@ def get_noisy_model_input_and_timesteps(
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): def apply_model_prediction_type(
args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Apply model prediction type to the model prediction and the sigmas.
Args:
args (argparse.Namespace): Arguments.
model_pred (Tensor): Model prediction.
noisy_model_input (Tensor): Noisy model input.
sigmas (Tensor): Sigmas.
Return:
Tuple[Tensor, Optional[Tensor]]:
"""
weighting = None weighting = None
if args.model_prediction_type == "raw": if args.model_prediction_type == "raw":
pass pass
@@ -433,10 +612,22 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
def save_models( def save_models(
ckpt_path: str, ckpt_path: str,
lumina: lumina_models.NextDiT, lumina: lumina_models.NextDiT,
sai_metadata: Optional[dict], sai_metadata: Dict[str, Any],
save_dtype: Optional[torch.dtype] = None, save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False, use_mem_eff_save: bool = False,
): ):
"""
Save the model to the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
lumina (lumina_models.NextDiT): NextDIT model.
sai_metadata (Dict[str, Any]): Metadata for the SAI model.
save_dtype (Optional[torch.dtype]): Data type. Defaults to None.
use_mem_eff_save (bool): Defaults to False.
Return:
None
"""
state_dict = {} state_dict = {}
def update_sd(prefix, sd): def update_sd(prefix, sd):
@@ -458,7 +649,9 @@ def save_lumina_model_on_train_end(
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT
): ):
def sd_saver(ckpt_file, epoch_no, global_step): def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
)
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@@ -469,15 +662,29 @@ def save_lumina_model_on_train_end(
def save_lumina_model_on_epoch_end_or_stepwise( def save_lumina_model_on_epoch_end_or_stepwise(
args: argparse.Namespace, args: argparse.Namespace,
on_epoch_end: bool, on_epoch_end: bool,
accelerator, accelerator: Accelerator,
save_dtype: torch.dtype, save_dtype: torch.dtype,
epoch: int, epoch: int,
num_train_epochs: int, num_train_epochs: int,
global_step: int, global_step: int,
lumina: lumina_models.NextDiT, lumina: lumina_models.NextDiT,
): ):
def sd_saver(ckpt_file, epoch_no, global_step): """
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") Save the model to the checkpoint path.
Args:
args (argparse.Namespace): Arguments.
save_dtype (torch.dtype): Data type.
epoch (int): Epoch.
global_step (int): Global step.
lumina (lumina_models.NextDiT): NextDIT model.
Return:
None
"""
def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common( train_util.save_sd_model_on_epoch_end_or_stepwise_common(

View File

@@ -11,23 +11,33 @@ from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import lumina_models, flux_models from library import lumina_models, flux_models
from library.utils import load_safetensors from library.utils import load_safetensors
import logging
setup_logging()
logger = logging.getLogger(__name__)
MODEL_VERSION_LUMINA_V2 = "lumina2" MODEL_VERSION_LUMINA_V2 = "lumina2"
def load_lumina_model( def load_lumina_model(
ckpt_path: str, ckpt_path: str,
dtype: torch.dtype, dtype: torch.dtype,
device: Union[str, torch.device], device: torch.device,
disable_mmap: bool = False, disable_mmap: bool = False,
): ):
"""
Load the Lumina model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (torch.device): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
Returns:
model (lumina_models.NextDiT): The loaded model.
"""
logger.info("Building Lumina") logger.info("Building Lumina")
with torch.device("meta"): with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
@@ -46,6 +56,18 @@ def load_ae(
device: Union[str, torch.device], device: Union[str, torch.device],
disable_mmap: bool = False, disable_mmap: bool = False,
) -> flux_models.AutoEncoder: ) -> flux_models.AutoEncoder:
"""
Load the AutoEncoder model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (Union[str, torch.device]): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
Returns:
ae (flux_models.AutoEncoder): The loaded model.
"""
logger.info("Building AutoEncoder") logger.info("Building AutoEncoder")
with torch.device("meta"): with torch.device("meta"):
# dev and schnell have the same AE params # dev and schnell have the same AE params
@@ -67,6 +89,19 @@ def load_gemma2(
disable_mmap: bool = False, disable_mmap: bool = False,
state_dict: Optional[dict] = None, state_dict: Optional[dict] = None,
) -> Gemma2Model: ) -> Gemma2Model:
"""
Load the Gemma2 model from the checkpoint path.
Args:
ckpt_path (str): Path to the checkpoint.
dtype (torch.dtype): The data type for the model.
device (Union[str, torch.device]): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
state_dict (Optional[dict], optional): The state dict to load. Defaults to None.
Returns:
gemma2 (Gemma2Model): The loaded model
"""
logger.info("Building Gemma2") logger.info("Building Gemma2")
GEMMA2_CONFIG = { GEMMA2_CONFIG = {
"_name_or_path": "google/gemma-2-2b", "_name_or_path": "google/gemma-2-2b",

View File

@@ -130,11 +130,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return False return False
if "input_ids" not in npz: if "input_ids" not in npz:
return False return False
if "apply_gemma2_attn_mask" not in npz:
return False
npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
if not npz_apply_gemma2_attn_mask:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -142,11 +137,17 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return True return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
    """
    Load cached text encoder outputs from an ``.npz`` file.

    Args:
        npz_path (str): Path of the ``.npz`` file to read.

    Returns:
        List[np.ndarray]: ``[hidden_state, input_ids, attention_mask]``, in that order.

    Raises:
        KeyError: If any of the three arrays is missing from the archive.
    """
    # np.load on an .npz returns a lazily-read archive that holds the file open;
    # read every array inside the context manager so the handle is released.
    with np.load(npz_path) as data:
        hidden_state = data["hidden_state"]
        attention_mask = data["attention_mask"]
        input_ids = data["input_ids"]
    return [hidden_state, input_ids, attention_mask]
def cache_batch_outputs( def cache_batch_outputs(
self, self,
@@ -193,8 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
info.text_encoder_outputs_npz, info.text_encoder_outputs_npz,
hidden_state=hidden_state_i, hidden_state=hidden_state_i,
attention_mask=attention_mask_i, attention_mask=attention_mask_i,
input_ids=input_ids_i, input_ids=input_ids_i
apply_gemma2_attn_mask=True
) )
else: else:
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]

View File

@@ -2,9 +2,10 @@ import argparse
import copy import copy
import math import math
import random import random
from typing import Any, Optional, Union from typing import Any, Optional, Union, Tuple
import torch import torch
from torch import Tensor
from accelerate import Accelerator from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex from library.device_utils import clean_memory_on_device, init_ipex
@@ -165,36 +166,31 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
) )
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = ( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
)
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
strategy_base.TextEncodingStrategy.get_strategy() assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
)
prompts = train_util.load_prompts(args.sample_prompts) sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = ( sample_prompts_te_outputs = (
{} {}
) # key: prompt, value: text encoder outputs ) # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad(): with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts: for prompt_dict in sample_prompts:
for p in [ prompts = [prompt_dict.get("prompt", ""),
prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]
prompt_dict.get("negative_prompt", ""), logger.info(
]: f"cache Text Encoder outputs for prompt: {prompts[0]}"
if p not in sample_prompts_te_outputs: )
logger.info( tokens_and_masks = tokenize_strategy.tokenize(prompts)
f"cache Text Encoder outputs for prompt: {p}" sample_prompts_te_outputs[prompts[0]] = (
) text_encoding_strategy.encode_tokens(
tokens_and_masks = tokenize_strategy.tokenize(p) tokenize_strategy,
sample_prompts_te_outputs[p] = ( text_encoders,
text_encoding_strategy.encode_tokens( tokens_and_masks,
tokenize_strategy, )
text_encoders, )
tokens_and_masks,
args.apply_t5_attn_mask,
)
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -220,7 +216,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
epoch, epoch,
global_step, global_step,
device, device,
ae, vae,
tokenizer, tokenizer,
text_encoder, text_encoder,
lumina, lumina,
@@ -231,7 +227,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
epoch, epoch,
global_step, global_step,
lumina, lumina,
ae, vae,
self.get_models_for_text_encoding(args, accelerator, text_encoder), self.get_models_for_text_encoding(args, accelerator, text_encoder),
self.sample_prompts_te_outputs, self.sample_prompts_te_outputs,
) )
@@ -258,12 +254,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def get_noise_pred_and_target( def get_noise_pred_and_target(
self, self,
args, args,
accelerator, accelerator: Accelerator,
noise_scheduler, noise_scheduler,
latents, latents,
batch, batch,
text_encoder_conds, text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
unet: lumina_models.NextDiT, dit: lumina_models.NextDiT,
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
@@ -296,7 +292,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
with torch.set_grad_enabled(is_train), accelerator.autocast(): with torch.set_grad_enabled(is_train), accelerator.autocast():
# NextDiT forward expects (x, t, cap_feats, cap_mask) # NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = unet( model_pred = dit(
x=img, # image latents (B, C, H, W) x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
@@ -341,7 +337,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
network.set_multiplier(0.0) network.set_multiplier(0.0)
with torch.no_grad(): with torch.no_grad():
model_pred_prior = call_dit( model_pred_prior = call_dit(
img=packed_noisy_model_input[diff_output_pr_indices], img=noisy_model_input[diff_output_pr_indices],
gemma2_hidden_states=gemma2_hidden_states[ gemma2_hidden_states=gemma2_hidden_states[
diff_output_pr_indices diff_output_pr_indices
], ],
@@ -350,9 +346,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
) )
network.set_multiplier(1.0) network.set_multiplier(1.0)
model_pred_prior = lumina_util.unpack_latents( # model_pred_prior = lumina_util.unpack_latents(
model_pred_prior, packed_latent_height, packed_latent_width # model_pred_prior, packed_latent_height, packed_latent_width
) # )
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
args, args,
model_pred_prior, model_pred_prior,
@@ -404,7 +400,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return super().prepare_unet_with_accelerator(args, accelerator, unet) return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we doesn't swap blocks, we can move the model to device # if we doesn't swap blocks, we can move the model to device
nextdit: lumina_models.Nextdit = unet nextdit = unet
assert isinstance(nextdit, lumina_models.NextDiT)
nextdit = accelerator.prepare( nextdit = accelerator.prepare(
nextdit, device_placement=[not self.is_swapping_blocks] nextdit, device_placement=[not self.is_swapping_blocks]
) )