From d154e76c457a526d8af0853c92edab98cade22f6 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 12 Feb 2025 16:30:05 +0800 Subject: [PATCH] init --- library/lumina_models.py | 1144 ++++++++++++++++++++++++++++++++++ library/lumina_train_util.py | 554 ++++++++++++++++ library/lumina_util.py | 194 ++++++ library/sai_model_spec.py | 12 + library/strategy_lumina.py | 275 ++++++++ library/train_util.py | 2 + lumina_train_network.py | 192 ++++++ 7 files changed, 2373 insertions(+) create mode 100644 library/lumina_models.py create mode 100644 library/lumina_train_util.py create mode 100644 library/lumina_util.py create mode 100644 library/strategy_lumina.py create mode 100644 lumina_train_network.py diff --git a/library/lumina_models.py b/library/lumina_models.py new file mode 100644 index 00000000..43b1e9c6 --- /dev/null +++ b/library/lumina_models.py @@ -0,0 +1,1144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math +from typing import List, Optional, Tuple +from dataclasses import dataclass + +from flash_attn import flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except ImportError: + warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + +@dataclass +class LuminaParams: + """Parameters for Lumina model configuration""" + patch_size: int = 2 + dim: int = 2592 + n_layers: int = 30 + n_heads: int = 24 + n_kv_heads: int = 8 + axes_dims: List[int] = None + axes_lens: List[int] = None + qk_norm: bool = False, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + scaling_factor: float = 1.0, + cap_feat_dim: int = 32, + + def __post_init__(self): + if self.axes_dims is None: + self.axes_dims = [36, 36, 36] + if self.axes_lens is None: + self.axes_lens = [300, 512, 512] + + @classmethod + def get_2b_config(cls) -> "LuminaParams": + """Returns the configuration for the 2B parameter model""" + return cls( + patch_size=2, + dim=2592, + n_layers=30, + n_heads=24, + n_kv_heads=8, + axes_dims=[36, 36, 36], + axes_lens=[300, 512, 512] + ) + + @classmethod + def get_7b_config(cls) -> "LuminaParams": + """Returns the configuration for the 7B parameter model""" + return cls( + patch_size=2, + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[64, 64, 64], + axes_lens=[300, 512, 512] + ) + + +############################################################################# +# RMSNorm # +############################################################################# + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def modulate(x, scale): + return x * (1 + scale.unsqueeze(1)) + + +############################################################################# +# Embedding Layers for Timesteps and Class Labels # +############################################################################# + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + hidden_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + hidden_size, + hidden_size, + bias=True, + ), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.zeros_(self.mlp[0].bias) + nn.init.normal_(self.mlp[2].weight, std=0.02) + nn.init.zeros_(self.mlp[2].bias) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + return t_emb + + +############################################################################# +# Core NextDiT Model # +############################################################################# + + +class JointAttention(nn.Module): + """Multi-head attention module.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + ): + """ + Initialize the Attention module. + + Args: + dim (int): Number of input dimensions. + n_heads (int): Number of heads. + n_kv_heads (Optional[int]): Number of kv heads, if using GQA. + + """ + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.qkv = nn.Linear( + dim, + (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.qkv.weight) + + self.out = nn.Linear( + n_heads * self.head_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.out.weight) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + else: + self.q_norm = self.k_norm = nn.Identity() + + @staticmethod + def apply_rotary_emb( + x_in: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency + tensor. + + This function applies rotary embeddings to the given query 'xq' and + key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The + input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors + contain rotary embeddings and are returned as real tensors. + + Args: + x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex + exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor + and key tensor with rotary embeddings. + """ + with torch.amp.autocast("cuda",enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) + + # copied from huggingface modeling_llama.py + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape( + batch_size * kv_seq_len, self.n_local_heads, head_dim + ), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + """ + + Args: + x: + x_mask: + freqs_cis: + + Returns: + + """ + bsz, seqlen, _ = x.shape + dtype = x.dtype + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + softmax_scale = math.sqrt(1 / self.head_dim) + + if dtype in [torch.float16, torch.bfloat16]: + # begin var_len flash attn + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(xq, xk, xv, x_mask, seqlen) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) + # end var_len_flash_attn + + else: + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + output = ( + F.scaled_dot_product_attention( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + attn_mask=x_mask.bool() + .view(bsz, 1, 1, seqlen) + .expand(-1, self.n_local_heads, seqlen, -1), + scale=softmax_scale, + ) + .permute(0, 2, 1, 3) + .to(dtype) + ) + + output = output.flatten(-2) + + return self.out(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + """ + super().__init__() + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w1.weight) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w2.weight) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w3.weight) + + # @torch.compile + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class JointTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + ) -> None: + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + dim (int): Embedding dimension of the input features. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): + ffn_dim_multiplier (float): + norm_eps (float): + + """ + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(dim, 1024), + 4 * dim, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and + feedforward layers. + + """ + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation( + adaln_input + ).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + self.attention( + modulate(self.attention_norm1(x), scale_msa), + x_mask, + freqs_cis, + ) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + self.feed_forward( + modulate(self.ffn_norm1(x), scale_mlp), + ) + ) + else: + assert adaln_input is None + x = x + self.attention_norm2( + self.attention( + self.attention_norm1(x), + x_mask, + freqs_cis, + ) + ) + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of NextDiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + ) + self.linear = nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(hidden_size, 1024), + hidden_size, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + scale = self.adaLN_modulation(c) + x = modulate(self.norm_final(x), scale) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 10000.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (1, 512, 512), + ): + super().__init__() + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.freqs_cis = NextDiT.precompute_freqs_cis( + self.axes_dims, self.axes_lens, theta=self.theta + ) + + def __call__(self, ids: torch.Tensor): + self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] + result = [] + for i in range(len(self.axes_dims)): + # import torch.distributed as dist + # if not dist.is_initialized() or dist.get_rank() == 0: + # import pdb + # pdb.set_trace() + index = ( + ids[:, :, i : i + 1] + .repeat(1, 1, self.freqs_cis[i].shape[-1]) + .to(torch.int64) + ) + result.append( + torch.gather( + self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), + dim=1, + index=index, + ) + ) + return torch.cat(result, dim=-1) + + +class NextDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 4, + dim: int = 4096, + n_layers: int = 32, + n_refiner_layers: int = 2, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + qk_norm: bool = False, + cap_feat_dim: int = 5120, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (1, 512, 512), + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=dim, + bias=True, + ) + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + self.noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder(min(dim, 1024)) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + dim, + bias=True, + ), + ) + nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) + # nn.init.zeros_(self.cap_embedder[1].weight) + nn.init.zeros_(self.cap_embedder[1].bias) + + self.layers = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + ) + for layer_id in range(n_layers) + ] + ) + self.norm_final = RMSNorm(dim, eps=norm_eps) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels) + + assert (dim // n_heads) == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens) + self.dim = dim + self.n_heads = n_heads + + def unpatchify( + self, + x: torch.Tensor, + img_size: List[Tuple[int, int]], + cap_size: List[int], + return_tensor=False, + ) -> List[torch.Tensor]: + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + pH = pW = self.patch_size + imgs = [] + for i in range(x.size(0)): + H, W = img_size[i] + begin = cap_size[i] + end = begin + (H // pH) * (W // pW) + imgs.append( + x[i][begin:end] + .view(H // pH, W // pW, pH, pW, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + + if return_tensor: + imgs = torch.stack(imgs, dim=0) + return imgs + + def patchify_and_embed( + self, + x: List[torch.Tensor] | torch.Tensor, + cap_feats: torch.Tensor, + cap_mask: torch.Tensor, + t: torch.Tensor, + ) -> Tuple[ + torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor + ]: + bsz = len(x) + pH = pW = self.patch_size + device = x[0].device + + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + img_sizes = [(img.size(1), img.size(2)) for img in x] + l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + + max_seq_len = max( + ( + cap_len + img_len + for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len) + ) + ) + max_cap_len = max(l_effective_cap_len) + max_img_len = max(l_effective_img_len) + + position_ids = torch.zeros( + bsz, max_seq_len, 3, dtype=torch.int32, device=device + ) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + H, W = img_sizes[i] + H_tokens, W_tokens = H // pH, W // pW + assert H_tokens * W_tokens == img_len + + position_ids[i, :cap_len, 0] = torch.arange( + cap_len, dtype=torch.int32, device=device + ) + position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + row_ids = ( + torch.arange(H_tokens, dtype=torch.int32, device=device) + .view(-1, 1) + .repeat(1, W_tokens) + .flatten() + ) + col_ids = ( + torch.arange(W_tokens, dtype=torch.int32, device=device) + .view(1, -1) + .repeat(H_tokens, 1) + .flatten() + ) + position_ids[i, cap_len : cap_len + img_len, 1] = row_ids + position_ids[i, cap_len : cap_len + img_len, 2] = col_ids + + freqs_cis = self.rope_embedder(position_ids) + + # build freqs_cis for cap and image individually + cap_freqs_cis_shape = list(freqs_cis.shape) + # cap_freqs_cis_shape[1] = max_cap_len + cap_freqs_cis_shape[1] = cap_feats.shape[1] + cap_freqs_cis = torch.zeros( + *cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype + ) + + img_freqs_cis_shape = list(freqs_cis.shape) + img_freqs_cis_shape[1] = max_img_len + img_freqs_cis = torch.zeros( + *img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype + ) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + + # refine context + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + # refine image + flat_x = [] + for i in range(bsz): + img = x[i] + C, H, W = img.size() + img = ( + img.view(C, H // pH, pH, W // pW, pW) + .permute(1, 3, 2, 4, 0) + .flatten(2) + .flatten(0, 1) + ) + flat_x.append(img) + x = flat_x + padded_img_embed = torch.zeros( + bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype + ) + padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) + for i in range(bsz): + padded_img_embed[i, : l_effective_img_len[i]] = x[i] + padded_img_mask[i, : l_effective_img_len[i]] = True + + padded_img_embed = self.x_embedder(padded_img_embed) + for layer in self.noise_refiner: + padded_img_embed = layer( + padded_img_embed, padded_img_mask, img_freqs_cis, t + ) + + mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + padded_full_embed = torch.zeros( + bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype + ) + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + + mask[i, : cap_len + img_len] = True + padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] + padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[ + i, :img_len + ] + + return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + + def forward(self, x, t, cap_feats, cap_mask): + """ + Forward pass of NextDiT. + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of text tokens/features + """ + + # import torch.distributed as dist + # if not dist.is_initialized() or dist.get_rank() == 0: + # import pdb + # pdb.set_trace() + # torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt") + t = self.t_embedder(t) # (N, D) + adaln_input = t + + cap_feats = self.cap_embedder( + cap_feats + ) # (N, L, D) # todo check if able to batchify w.o. redundant compute + + x_is_tensor = isinstance(x, torch.Tensor) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + freqs_cis = freqs_cis.to(x.device) + + for layer in self.layers: + x = layer(x, mask, freqs_cis, adaln_input) + + x = self.final_layer(x, adaln_input) + x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor) + + return x + + def forward_with_cfg( + self, x, t, cap_feats, cap_mask, cfg_scale, cfg_trunc=100, renorm_cfg=1 + ): + """ + Forward pass of NextDiT, but also batches the unconditional forward pass + for classifier-free guidance. + """ + # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + if t[0] < cfg_trunc: + combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128] + model_out = self.forward( + combined, t, cap_feats, cap_mask + ) # [2, 16, 128, 128] + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if float(renorm_cfg) > 0.0: + ori_pos_norm = torch.linalg.vector_norm( + cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True + ) + max_new_norm = ori_pos_norm * float(renorm_cfg) + new_pos_norm = torch.linalg.vector_norm( + half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True + ) + if new_pos_norm >= max_new_norm: + half_eps = half_eps * (max_new_norm / new_pos_norm) + else: + combined = half + model_out = self.forward( + combined, + t[: len(x) // 2], + cap_feats[: len(x) // 2], + cap_mask[: len(x) // 2], + ) + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + half_eps = eps + + output = torch.cat([half_eps, half_eps], dim=0) + return output + + @staticmethod + def precompute_freqs_cis( + dim: List[int], + end: List[int], + theta: float = 10000.0, + ): + """ + Precompute the frequency tensor for complex exponentials (cis) with + given dimensions. + + This function calculates a frequency tensor with complex exponentials + using the given dimension 'dim' and the end index 'end'. The 'theta' + parameter scales the frequencies. The returned tensor contains complex + values in complex64 data type. + + Args: + dim (list): Dimension of the frequency tensor. + end (list): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. + Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex + exponentials. + """ + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / ( + theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d) + ) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to( + torch.complex64 + ) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def parameter_count(self) -> int: + total_params = 0 + + def _recursive_count_params(module): + nonlocal total_params + for param in module.parameters(recurse=False): + total_params += param.numel() + for submodule in module.children(): + _recursive_count_params(submodule) + + _recursive_count_params(self) + return total_params + + def get_fsdp_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + +############################################################################# +# NextDiT Configs # +############################################################################# + + +def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs): + if params is None: + params = LuminaParams.get_2b_config() + + return NextDiT( + patch_size=params.patch_size, + dim=params.dim, + n_layers=params.n_layers, + n_heads=params.n_heads, + n_kv_heads=params.n_kv_heads, + axes_dims=params.axes_dims, + axes_lens=params.axes_lens, + qk_norm=params.qk_norm, + ffn_dim_multiplier=params.ffn_dim_multiplier, + norm_eps=params.norm_eps, + scaling_factor=params.scaling_factor, + cap_feat_dim=params.cap_feat_dim, + **kwargs, + ) + + +def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2592, + n_layers=30, + n_heads=24, + n_kv_heads=8, + axes_dims=[36, 36, 36], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2880, + n_layers=32, + n_heads=24, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=3840, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py new file mode 100644 index 00000000..d3edd262 --- /dev/null +++ b/library/lumina_train_util.py @@ -0,0 +1,554 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import AutoTokenizer, AutoModelForCausalLM +from tqdm import tqdm +from PIL import Image +from safetensors.torch import save_file + +from library import lumina_models, lumina_util, strategy_base, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sample images + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + nextdit, + ae, + gemma2_model, + sample_prompts_gemma2_outputs, + prompt_replacement=None, + controlnet=None +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap nextdit and gemma2_model + nextdit = accelerator.unwrap_model(nextdit) + if gemma2_model is not None: + gemma2_model = accelerator.unwrap_model(gemma2_model) + # if controlnet is not None: + # controlnet = accelerator.unwrap_model(controlnet) + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + nextdit, + gemma2_model, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_gemma2_outputs, + prompt_replacement, + # controlnet +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + gemma2_conds = [] + if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: + gemma2_conds = sample_prompts_gemma2_outputs[prompt] + print(f"Using cached Gemma2 outputs for prompt: {prompt}") + if gemma2_model is not None: + print(f"Encoding prompt with Gemma2: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_gemma2_attn_mask option + encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + + # if gemma2_conds is not cached, use encoded_gemma2_conds + if len(gemma2_conds) == 0: + gemma2_conds = encoded_gemma2_conds + else: + # if encoded_gemma2_conds is not None, update cached gemma2_conds + for i in range(len(encoded_gemma2_conds)): + if encoded_gemma2_conds[i] is not None: + gemma2_conds[i] = encoded_gemma2_conds[i] + + # Unpack Gemma2 outputs + gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) + img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None + + # if controlnet_image is not None: + # controlnet_image = Image.open(controlnet_image).convert("RGB") + # controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + + with accelerator.autocast(), torch.no_grad(): + x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + + x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, _, h, w = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "nextdit_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + lumina: lumina_models.NextDiT, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", lumina.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_lumina_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_lumina_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + lumina: lumina_models.NextDiT, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_lumina_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--gemma2", + type=str, + help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--gemma2_max_token_length", + type=int, + default=None, + help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev" + " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_gemma2_attn_mask", + action="store_true", + help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the NextDIT.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) diff --git a/library/lumina_util.py b/library/lumina_util.py new file mode 100644 index 00000000..990f8c68 --- /dev/null +++ b/library/lumina_util.py @@ -0,0 +1,194 @@ +import json +import os +from dataclasses import replace +from typing import List, Optional, Tuple, Union + +import einops +import torch +from accelerate import init_empty_weights +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import Gemma2Config, Gemma2Model + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import lumina_models, flux_models +from library.utils import load_safetensors + +MODEL_VERSION_LUMINA_V2 = "lumina2" + +def load_lumina_model( + ckpt_path: str, + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, +) -> lumina_models.Lumina: + logger.info("Building Lumina") + with torch.device("meta"): + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + state_dict = load_safetensors( + ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ) + info = model.load_state_dict(state_dict, strict=False, assign=True) + logger.info(f"Loaded Lumina: {info}") + return model + +def load_ae( + ckpt_path: str, + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, +) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors( + ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_gemma2( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> Gemma2Model: + logger.info("Building Gemma2") + GEMMA2_CONFIG = { + "_name_or_path": "google/gemma-2b", + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "head_dim": 256, + "hidden_act": "gelu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 8192, + "model_type": "gemma", + "num_attention_heads": 8, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.0.dev0", + "use_cache": true, + "vocab_size": 256000 + } + config = Gemma2Config(**GEMMA2_CONFIG) + with init_empty_weights(): + gemma2 = Gemma2Model._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors( + ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ) + info = gemma2.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Gemma2: {info}") + return gemma2 + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + +DIFFUSERS_TO_ALPHA_VLLM_MAP = { + # Embedding layers + "cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"], + "cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight", + "cap_embedder.1.bias": "text_embedder.1.bias", + "x_embedder.weight": "patch_embedder.proj.weight", + "x_embedder.bias": "patch_embedder.proj.bias", + # Attention modulation + "layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight", + "layers.().adaLN_modulation.1.bias": "transformer_blocks.().adaln_modulation.1.bias", + # Final layers + "final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight", + "final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias", + "final_layer.linear.weight": "final_linear.weight", + "final_layer.linear.bias": "final_linear.bias", + # Noise refiner + "noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight", + "noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias", + "noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight", + "noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight", + # Time embedding + "t_embedder.mlp.0.weight": "time_embedder.0.weight", + "t_embedder.mlp.0.bias": "time_embedder.0.bias", + "t_embedder.mlp.2.weight": "time_embedder.2.weight", + "t_embedder.mlp.2.bias": "time_embedder.2.bias", + # Context attention + "context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight", + "context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight", + # Normalization + "layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight", + "layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight", + # FFN + "layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight", + "layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight", + "layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight", +} + + +def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict: + """Convert Diffusers checkpoint to Alpha-VLLM format""" + logger.info("Converting Diffusers checkpoint to Alpha-VLLM format") + new_sd = {} + + for key, value in sd.items(): + new_key = key + for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): + if "()." in pattern: + for block_idx in range(num_double_blocks): + if str(block_idx) in key: + converted = pattern.replace("()", str(block_idx)) + new_key = key.replace( + converted, replacement.replace("()", str(block_idx)) + ) + break + + if new_key == key: + logger.debug(f"Unmatched key in conversion: {key}") + new_sd[new_key] = value + + logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") + return new_sd diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 8896c047..1e97c9cd 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -61,6 +61,8 @@ ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. # ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" ARCH_FLUX_1_UNKNOWN = "flux-1" +ARCH_LUMINA_2 = "lumina-2" +ARCH_LUMINA_UNKNOWN = "lumina" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" IMPL_FLUX = "https://github.com/black-forest-labs/flux" +IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -123,6 +126,7 @@ def build_metadata( clip_skip: Optional[int] = None, sd3: Optional[str] = None, flux: Optional[str] = None, + lumina: Optional[str] = None, ): """ sd3: only supports "m", flux: only supports "dev" @@ -146,6 +150,11 @@ def build_metadata( arch = ARCH_FLUX_1_DEV else: arch = ARCH_FLUX_1_UNKNOWN + elif lumina is not None: + if lumina == "lumina2": + arch = ARCH_LUMINA_2 + else: + arch = ARCH_LUMINA_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -167,6 +176,9 @@ def build_metadata( if flux is not None: # Flux impl = IMPL_FLUX + elif lumina is not None: + # Lumina + impl = IMPL_LUMINA elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py new file mode 100644 index 00000000..622c019a --- /dev/null +++ b/library/strategy_lumina.py @@ -0,0 +1,275 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import AutoTokenizer, AutoModel +from library import train_util +from library.strategy_base import ( + LatentsCachingStrategy, + TokenizeStrategy, + TextEncodingStrategy, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +GEMMA_ID = "google/gemma-2-2b" + + +class LuminaTokenizeStrategy(TokenizeStrategy): + def __init__( + self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + GEMMA_ID, cache_dir=tokenizer_cache_dir + ) + self.tokenizer.padding_side = "right" + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + encodings = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + return_tensors="pt", + truncation=True, + ) + return [encodings.input_ids] + + def tokenize_with_weights( + self, text: str | List[str] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # Gemma doesn't support weighted prompts, return uniform weights + tokens = self.tokenize(text) + weights = [torch.ones_like(t) for t in tokens] + return tokens, weights + + +class LuminaTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None: + super().__init__() + self.apply_gemma2_attn_mask = apply_gemma2_attn_mask + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_gemma2_attn_mask: Optional[bool] = None, + ) -> List[torch.Tensor]: + + if apply_gemma2_attn_mask is None: + apply_gemma2_attn_mask = self.apply_gemma2_attn_mask + + text_encoder = models[0] + input_ids = tokens[0].to(text_encoder.device) + + attention_mask = None + position_ids = None + if apply_gemma2_attn_mask: + # Create attention mask (1 for non-padding, 0 for padding) + attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to( + text_encoder.device + ) + + # Create position IDs + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + with torch.no_grad(): + outputs = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + return_dict=True, + ) + # Get the last hidden state + hidden_states = outputs.last_hidden_state + + return [hidden_states] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + # For simplicity, use uniform weighting + return self.encode_tokens(tokenize_strategy, models, tokens_list) + + +class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_gemma2_attn_mask: bool = False, + ) -> None: + super().__init__( + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + is_partial, + ) + self.apply_gemma2_attn_mask = apply_gemma2_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return ( + os.path.splitext(image_abs_path)[0] + + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + ) + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state = data["hidden_state"] + return [hidden_state] + + def cache_batch_outputs( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + infos: List, + ): + lumina_text_encoding_strategy: LuminaTextEncodingStrategy = ( + text_encoding_strategy + ) + captions = [info.caption for info in infos] + + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights( + captions + ) + with torch.no_grad(): + hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + )[0] + else: + tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state = lumina_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens + )[0] + + if hidden_state.dtype == torch.bfloat16: + hidden_state = hidden_state.float() + + hidden_state = hidden_state.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state_i = hidden_state[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state=hidden_state_i, + ) + else: + info.text_encoder_outputs = [hidden_state_i] + + +class LuminaLatentsCachingStrategy(LatentsCachingStrategy): + LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path( + self, absolute_path: str, image_size: Tuple[int, int] + ) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + ): + return self._default_is_disk_cached_latents_expected( + 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True + ) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: + return self._default_load_latents_from_disk( + 8, npz_path, bucket_reso + ) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents( + self, + vae, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + ): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, + vae_device, + vae_dtype, + image_infos, + flip_aug, + alpha_mask, + random_crop, + multi_resolution=True, + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a99..34ffe22b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3463,6 +3463,7 @@ def get_sai_model_spec( is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, flux: str = None, + lumina: str = None, ): timestamp = time.time() @@ -3498,6 +3499,7 @@ def get_sai_model_spec( clip_skip=args.clip_skip, # None or int sd3=sd3, flux=flux, + lumina=lumina, ) return metadata diff --git a/lumina_train_network.py b/lumina_train_network.py new file mode 100644 index 00000000..40b84e14 --- /dev/null +++ b/lumina_train_network.py @@ -0,0 +1,192 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional, Union + +import torch +from accelerate import Accelerator + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import train_network +from library import ( + lumina_models, + flux_train_utils, + lumina_util, + lumina_train_util, + sd3_train_utils, + strategy_base, + strategy_lumina, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LuminaNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group, val_dataset_group): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + + if ( + args.cache_text_encoder_outputs_to_disk + and not args.cache_text_encoder_outputs + ): + logger.warning("Enabling cache_text_encoder_outputs due to disk caching") + args.cache_text_encoder_outputs = True + + train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) + + self.train_gemma2 = not args.network_train_unet_only + + def load_target_model(self, args, weight_dtype, accelerator): + loading_dtype = None if args.fp8 else weight_dtype + + model = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + loading_dtype, + "cpu", + disable_mmap=args.disable_mmap_load_safetensors, + ) + + # if args.blocks_to_swap: + # logger.info(f'Enabling block swap: {args.blocks_to_swap}') + # model.enable_block_swap(args.blocks_to_swap, accelerator.device) + # self.is_swapping_blocks = True + + gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + + return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model + + def get_tokenize_strategy(self, args): + return strategy_lumina.LuminaTokenizeStrategy( + args.gemma2_max_token_length, args.tokenizer_cache_dir + ) + + def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + return strategy_lumina.LuminaLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, False + ) + + def get_text_encoding_strategy(self, args): + return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask) + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_gemma2] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_gemma2, + apply_gemma2_attn_mask=args.apply_gemma2_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, + args, + accelerator: Accelerator, + unet, + vae, + text_encoders, + dataset, + weight_dtype, + ): + for text_encoder in text_encoders: + text_encoder_outputs_caching_strategy = ( + self.get_text_encoder_outputs_caching_strategy(args) + ) + if text_encoder_outputs_caching_strategy is not None: + text_encoder_outputs_caching_strategy.cache_batch_outputs( + self.get_tokenize_strategy(args), + [text_encoder], + self.get_text_encoding_strategy(args), + dataset, + ) + + def sample_images( + self, + accelerator, + args, + epoch, + global_step, + device, + ae, + tokenizer, + text_encoder, + lumina, + ): + lumina_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + lumina, + ae, + self.get_models_for_text_encoding(args, accelerator, text_encoder), + self.sample_prompts_te_outputs, + ) + + # Remaining methods maintain similar structure to flux implementation + # with Lumina-specific model calls and strategies + + def get_noise_scheduler( + self, args: argparse.Namespace, device: torch.device + ) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=args.discrete_flow_shift + ) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + # not sure, they use same flux vae + def shift_scale_latents(self, args, latents): + return latents + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + lumina_train_utils.add_lumina_train_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = LuminaNetworkTrainer() + trainer.train(args)