From d154e76c457a526d8af0853c92edab98cade22f6 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 12 Feb 2025 16:30:05 +0800 Subject: [PATCH 01/73] init --- library/lumina_models.py | 1144 ++++++++++++++++++++++++++++++++++ library/lumina_train_util.py | 554 ++++++++++++++++ library/lumina_util.py | 194 ++++++ library/sai_model_spec.py | 12 + library/strategy_lumina.py | 275 ++++++++ library/train_util.py | 2 + lumina_train_network.py | 192 ++++++ 7 files changed, 2373 insertions(+) create mode 100644 library/lumina_models.py create mode 100644 library/lumina_train_util.py create mode 100644 library/lumina_util.py create mode 100644 library/strategy_lumina.py create mode 100644 lumina_train_network.py diff --git a/library/lumina_models.py b/library/lumina_models.py new file mode 100644 index 00000000..43b1e9c6 --- /dev/null +++ b/library/lumina_models.py @@ -0,0 +1,1144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math +from typing import List, Optional, Tuple +from dataclasses import dataclass + +from flash_attn import flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except ImportError: + warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + +@dataclass +class LuminaParams: + """Parameters for Lumina model configuration""" + patch_size: int = 2 + dim: int = 2592 + n_layers: int = 30 + n_heads: int = 24 + n_kv_heads: int = 8 + axes_dims: List[int] = None + axes_lens: List[int] = None + qk_norm: bool = False, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + scaling_factor: float = 1.0, + cap_feat_dim: int = 32, + + def __post_init__(self): + if self.axes_dims is None: + self.axes_dims = [36, 36, 36] + if self.axes_lens is None: + self.axes_lens = [300, 512, 512] + + @classmethod + def get_2b_config(cls) -> "LuminaParams": + """Returns the configuration for the 2B parameter model""" + return cls( + patch_size=2, + dim=2592, + n_layers=30, + n_heads=24, + n_kv_heads=8, + axes_dims=[36, 36, 36], + axes_lens=[300, 512, 512] + ) + + @classmethod + def get_7b_config(cls) -> "LuminaParams": + """Returns the configuration for the 7B parameter model""" + return cls( + patch_size=2, + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[64, 64, 64], + axes_lens=[300, 512, 512] + ) + + +############################################################################# +# RMSNorm # +############################################################################# + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def modulate(x, scale): + return x * (1 + scale.unsqueeze(1)) + + +############################################################################# +# Embedding Layers for Timesteps and Class Labels # +############################################################################# + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + hidden_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + hidden_size, + hidden_size, + bias=True, + ), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.zeros_(self.mlp[0].bias) + nn.init.normal_(self.mlp[2].weight, std=0.02) + nn.init.zeros_(self.mlp[2].bias) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + return t_emb + + +############################################################################# +# Core NextDiT Model # +############################################################################# + + +class JointAttention(nn.Module): + """Multi-head attention module.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + ): + """ + Initialize the Attention module. + + Args: + dim (int): Number of input dimensions. + n_heads (int): Number of heads. + n_kv_heads (Optional[int]): Number of kv heads, if using GQA. + + """ + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.qkv = nn.Linear( + dim, + (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.qkv.weight) + + self.out = nn.Linear( + n_heads * self.head_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.out.weight) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + else: + self.q_norm = self.k_norm = nn.Identity() + + @staticmethod + def apply_rotary_emb( + x_in: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency + tensor. + + This function applies rotary embeddings to the given query 'xq' and + key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The + input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors + contain rotary embeddings and are returned as real tensors. + + Args: + x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex + exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor + and key tensor with rotary embeddings. + """ + with torch.amp.autocast("cuda",enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) + + # copied from huggingface modeling_llama.py + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape( + batch_size * kv_seq_len, self.n_local_heads, head_dim + ), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + """ + + Args: + x: + x_mask: + freqs_cis: + + Returns: + + """ + bsz, seqlen, _ = x.shape + dtype = x.dtype + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + softmax_scale = math.sqrt(1 / self.head_dim) + + if dtype in [torch.float16, torch.bfloat16]: + # begin var_len flash attn + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(xq, xk, xv, x_mask, seqlen) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) + # end var_len_flash_attn + + else: + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + output = ( + F.scaled_dot_product_attention( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + attn_mask=x_mask.bool() + .view(bsz, 1, 1, seqlen) + .expand(-1, self.n_local_heads, seqlen, -1), + scale=softmax_scale, + ) + .permute(0, 2, 1, 3) + .to(dtype) + ) + + output = output.flatten(-2) + + return self.out(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + """ + super().__init__() + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w1.weight) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w2.weight) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w3.weight) + + # @torch.compile + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class JointTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + ) -> None: + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + dim (int): Embedding dimension of the input features. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): + ffn_dim_multiplier (float): + norm_eps (float): + + """ + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(dim, 1024), + 4 * dim, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and + feedforward layers. + + """ + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation( + adaln_input + ).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + self.attention( + modulate(self.attention_norm1(x), scale_msa), + x_mask, + freqs_cis, + ) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + self.feed_forward( + modulate(self.ffn_norm1(x), scale_mlp), + ) + ) + else: + assert adaln_input is None + x = x + self.attention_norm2( + self.attention( + self.attention_norm1(x), + x_mask, + freqs_cis, + ) + ) + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of NextDiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + ) + self.linear = nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(hidden_size, 1024), + hidden_size, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + scale = self.adaLN_modulation(c) + x = modulate(self.norm_final(x), scale) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 10000.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (1, 512, 512), + ): + super().__init__() + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.freqs_cis = NextDiT.precompute_freqs_cis( + self.axes_dims, self.axes_lens, theta=self.theta + ) + + def __call__(self, ids: torch.Tensor): + self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] + result = [] + for i in range(len(self.axes_dims)): + # import torch.distributed as dist + # if not dist.is_initialized() or dist.get_rank() == 0: + # import pdb + # pdb.set_trace() + index = ( + ids[:, :, i : i + 1] + .repeat(1, 1, self.freqs_cis[i].shape[-1]) + .to(torch.int64) + ) + result.append( + torch.gather( + self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), + dim=1, + index=index, + ) + ) + return torch.cat(result, dim=-1) + + +class NextDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 4, + dim: int = 4096, + n_layers: int = 32, + n_refiner_layers: int = 2, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + qk_norm: bool = False, + cap_feat_dim: int = 5120, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (1, 512, 512), + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=dim, + bias=True, + ) + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + self.noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder(min(dim, 1024)) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + dim, + bias=True, + ), + ) + nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) + # nn.init.zeros_(self.cap_embedder[1].weight) + nn.init.zeros_(self.cap_embedder[1].bias) + + self.layers = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + ) + for layer_id in range(n_layers) + ] + ) + self.norm_final = RMSNorm(dim, eps=norm_eps) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels) + + assert (dim // n_heads) == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens) + self.dim = dim + self.n_heads = n_heads + + def unpatchify( + self, + x: torch.Tensor, + img_size: List[Tuple[int, int]], + cap_size: List[int], + return_tensor=False, + ) -> List[torch.Tensor]: + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + pH = pW = self.patch_size + imgs = [] + for i in range(x.size(0)): + H, W = img_size[i] + begin = cap_size[i] + end = begin + (H // pH) * (W // pW) + imgs.append( + x[i][begin:end] + .view(H // pH, W // pW, pH, pW, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + + if return_tensor: + imgs = torch.stack(imgs, dim=0) + return imgs + + def patchify_and_embed( + self, + x: List[torch.Tensor] | torch.Tensor, + cap_feats: torch.Tensor, + cap_mask: torch.Tensor, + t: torch.Tensor, + ) -> Tuple[ + torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor + ]: + bsz = len(x) + pH = pW = self.patch_size + device = x[0].device + + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + img_sizes = [(img.size(1), img.size(2)) for img in x] + l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + + max_seq_len = max( + ( + cap_len + img_len + for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len) + ) + ) + max_cap_len = max(l_effective_cap_len) + max_img_len = max(l_effective_img_len) + + position_ids = torch.zeros( + bsz, max_seq_len, 3, dtype=torch.int32, device=device + ) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + H, W = img_sizes[i] + H_tokens, W_tokens = H // pH, W // pW + assert H_tokens * W_tokens == img_len + + position_ids[i, :cap_len, 0] = torch.arange( + cap_len, dtype=torch.int32, device=device + ) + position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + row_ids = ( + torch.arange(H_tokens, dtype=torch.int32, device=device) + .view(-1, 1) + .repeat(1, W_tokens) + .flatten() + ) + col_ids = ( + torch.arange(W_tokens, dtype=torch.int32, device=device) + .view(1, -1) + .repeat(H_tokens, 1) + .flatten() + ) + position_ids[i, cap_len : cap_len + img_len, 1] = row_ids + position_ids[i, cap_len : cap_len + img_len, 2] = col_ids + + freqs_cis = self.rope_embedder(position_ids) + + # build freqs_cis for cap and image individually + cap_freqs_cis_shape = list(freqs_cis.shape) + # cap_freqs_cis_shape[1] = max_cap_len + cap_freqs_cis_shape[1] = cap_feats.shape[1] + cap_freqs_cis = torch.zeros( + *cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype + ) + + img_freqs_cis_shape = list(freqs_cis.shape) + img_freqs_cis_shape[1] = max_img_len + img_freqs_cis = torch.zeros( + *img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype + ) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + + # refine context + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + # refine image + flat_x = [] + for i in range(bsz): + img = x[i] + C, H, W = img.size() + img = ( + img.view(C, H // pH, pH, W // pW, pW) + .permute(1, 3, 2, 4, 0) + .flatten(2) + .flatten(0, 1) + ) + flat_x.append(img) + x = flat_x + padded_img_embed = torch.zeros( + bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype + ) + padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) + for i in range(bsz): + padded_img_embed[i, : l_effective_img_len[i]] = x[i] + padded_img_mask[i, : l_effective_img_len[i]] = True + + padded_img_embed = self.x_embedder(padded_img_embed) + for layer in self.noise_refiner: + padded_img_embed = layer( + padded_img_embed, padded_img_mask, img_freqs_cis, t + ) + + mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + padded_full_embed = torch.zeros( + bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype + ) + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + + mask[i, : cap_len + img_len] = True + padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] + padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[ + i, :img_len + ] + + return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + + def forward(self, x, t, cap_feats, cap_mask): + """ + Forward pass of NextDiT. + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of text tokens/features + """ + + # import torch.distributed as dist + # if not dist.is_initialized() or dist.get_rank() == 0: + # import pdb + # pdb.set_trace() + # torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt") + t = self.t_embedder(t) # (N, D) + adaln_input = t + + cap_feats = self.cap_embedder( + cap_feats + ) # (N, L, D) # todo check if able to batchify w.o. redundant compute + + x_is_tensor = isinstance(x, torch.Tensor) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + freqs_cis = freqs_cis.to(x.device) + + for layer in self.layers: + x = layer(x, mask, freqs_cis, adaln_input) + + x = self.final_layer(x, adaln_input) + x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor) + + return x + + def forward_with_cfg( + self, x, t, cap_feats, cap_mask, cfg_scale, cfg_trunc=100, renorm_cfg=1 + ): + """ + Forward pass of NextDiT, but also batches the unconditional forward pass + for classifier-free guidance. + """ + # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + if t[0] < cfg_trunc: + combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128] + model_out = self.forward( + combined, t, cap_feats, cap_mask + ) # [2, 16, 128, 128] + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if float(renorm_cfg) > 0.0: + ori_pos_norm = torch.linalg.vector_norm( + cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True + ) + max_new_norm = ori_pos_norm * float(renorm_cfg) + new_pos_norm = torch.linalg.vector_norm( + half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True + ) + if new_pos_norm >= max_new_norm: + half_eps = half_eps * (max_new_norm / new_pos_norm) + else: + combined = half + model_out = self.forward( + combined, + t[: len(x) // 2], + cap_feats[: len(x) // 2], + cap_mask[: len(x) // 2], + ) + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + half_eps = eps + + output = torch.cat([half_eps, half_eps], dim=0) + return output + + @staticmethod + def precompute_freqs_cis( + dim: List[int], + end: List[int], + theta: float = 10000.0, + ): + """ + Precompute the frequency tensor for complex exponentials (cis) with + given dimensions. + + This function calculates a frequency tensor with complex exponentials + using the given dimension 'dim' and the end index 'end'. The 'theta' + parameter scales the frequencies. The returned tensor contains complex + values in complex64 data type. + + Args: + dim (list): Dimension of the frequency tensor. + end (list): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. + Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex + exponentials. + """ + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / ( + theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d) + ) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to( + torch.complex64 + ) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def parameter_count(self) -> int: + total_params = 0 + + def _recursive_count_params(module): + nonlocal total_params + for param in module.parameters(recurse=False): + total_params += param.numel() + for submodule in module.children(): + _recursive_count_params(submodule) + + _recursive_count_params(self) + return total_params + + def get_fsdp_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + +############################################################################# +# NextDiT Configs # +############################################################################# + + +def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs): + if params is None: + params = LuminaParams.get_2b_config() + + return NextDiT( + patch_size=params.patch_size, + dim=params.dim, + n_layers=params.n_layers, + n_heads=params.n_heads, + n_kv_heads=params.n_kv_heads, + axes_dims=params.axes_dims, + axes_lens=params.axes_lens, + qk_norm=params.qk_norm, + ffn_dim_multiplier=params.ffn_dim_multiplier, + norm_eps=params.norm_eps, + scaling_factor=params.scaling_factor, + cap_feat_dim=params.cap_feat_dim, + **kwargs, + ) + + +def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2592, + n_layers=30, + n_heads=24, + n_kv_heads=8, + axes_dims=[36, 36, 36], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2880, + n_layers=32, + n_heads=24, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=3840, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py new file mode 100644 index 00000000..d3edd262 --- /dev/null +++ b/library/lumina_train_util.py @@ -0,0 +1,554 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import AutoTokenizer, AutoModelForCausalLM +from tqdm import tqdm +from PIL import Image +from safetensors.torch import save_file + +from library import lumina_models, lumina_util, strategy_base, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sample images + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + nextdit, + ae, + gemma2_model, + sample_prompts_gemma2_outputs, + prompt_replacement=None, + controlnet=None +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap nextdit and gemma2_model + nextdit = accelerator.unwrap_model(nextdit) + if gemma2_model is not None: + gemma2_model = accelerator.unwrap_model(gemma2_model) + # if controlnet is not None: + # controlnet = accelerator.unwrap_model(controlnet) + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + nextdit, + gemma2_model, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_gemma2_outputs, + prompt_replacement, + # controlnet +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + gemma2_conds = [] + if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: + gemma2_conds = sample_prompts_gemma2_outputs[prompt] + print(f"Using cached Gemma2 outputs for prompt: {prompt}") + if gemma2_model is not None: + print(f"Encoding prompt with Gemma2: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_gemma2_attn_mask option + encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + + # if gemma2_conds is not cached, use encoded_gemma2_conds + if len(gemma2_conds) == 0: + gemma2_conds = encoded_gemma2_conds + else: + # if encoded_gemma2_conds is not None, update cached gemma2_conds + for i in range(len(encoded_gemma2_conds)): + if encoded_gemma2_conds[i] is not None: + gemma2_conds[i] = encoded_gemma2_conds[i] + + # Unpack Gemma2 outputs + gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) + img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None + + # if controlnet_image is not None: + # controlnet_image = Image.open(controlnet_image).convert("RGB") + # controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + + with accelerator.autocast(), torch.no_grad(): + x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + + x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, _, h, w = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "nextdit_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + lumina: lumina_models.NextDiT, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", lumina.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_lumina_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_lumina_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + lumina: lumina_models.NextDiT, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_lumina_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--gemma2", + type=str, + help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--gemma2_max_token_length", + type=int, + default=None, + help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev" + " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_gemma2_attn_mask", + action="store_true", + help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the NextDIT.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) diff --git a/library/lumina_util.py b/library/lumina_util.py new file mode 100644 index 00000000..990f8c68 --- /dev/null +++ b/library/lumina_util.py @@ -0,0 +1,194 @@ +import json +import os +from dataclasses import replace +from typing import List, Optional, Tuple, Union + +import einops +import torch +from accelerate import init_empty_weights +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import Gemma2Config, Gemma2Model + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import lumina_models, flux_models +from library.utils import load_safetensors + +MODEL_VERSION_LUMINA_V2 = "lumina2" + +def load_lumina_model( + ckpt_path: str, + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, +) -> lumina_models.Lumina: + logger.info("Building Lumina") + with torch.device("meta"): + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + state_dict = load_safetensors( + ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ) + info = model.load_state_dict(state_dict, strict=False, assign=True) + logger.info(f"Loaded Lumina: {info}") + return model + +def load_ae( + ckpt_path: str, + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, +) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors( + ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_gemma2( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> Gemma2Model: + logger.info("Building Gemma2") + GEMMA2_CONFIG = { + "_name_or_path": "google/gemma-2b", + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "head_dim": 256, + "hidden_act": "gelu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 8192, + "model_type": "gemma", + "num_attention_heads": 8, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.0.dev0", + "use_cache": true, + "vocab_size": 256000 + } + config = Gemma2Config(**GEMMA2_CONFIG) + with init_empty_weights(): + gemma2 = Gemma2Model._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors( + ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ) + info = gemma2.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Gemma2: {info}") + return gemma2 + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + +DIFFUSERS_TO_ALPHA_VLLM_MAP = { + # Embedding layers + "cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"], + "cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight", + "cap_embedder.1.bias": "text_embedder.1.bias", + "x_embedder.weight": "patch_embedder.proj.weight", + "x_embedder.bias": "patch_embedder.proj.bias", + # Attention modulation + "layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight", + "layers.().adaLN_modulation.1.bias": "transformer_blocks.().adaln_modulation.1.bias", + # Final layers + "final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight", + "final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias", + "final_layer.linear.weight": "final_linear.weight", + "final_layer.linear.bias": "final_linear.bias", + # Noise refiner + "noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight", + "noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias", + "noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight", + "noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight", + # Time embedding + "t_embedder.mlp.0.weight": "time_embedder.0.weight", + "t_embedder.mlp.0.bias": "time_embedder.0.bias", + "t_embedder.mlp.2.weight": "time_embedder.2.weight", + "t_embedder.mlp.2.bias": "time_embedder.2.bias", + # Context attention + "context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight", + "context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight", + # Normalization + "layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight", + "layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight", + # FFN + "layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight", + "layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight", + "layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight", +} + + +def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict: + """Convert Diffusers checkpoint to Alpha-VLLM format""" + logger.info("Converting Diffusers checkpoint to Alpha-VLLM format") + new_sd = {} + + for key, value in sd.items(): + new_key = key + for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): + if "()." in pattern: + for block_idx in range(num_double_blocks): + if str(block_idx) in key: + converted = pattern.replace("()", str(block_idx)) + new_key = key.replace( + converted, replacement.replace("()", str(block_idx)) + ) + break + + if new_key == key: + logger.debug(f"Unmatched key in conversion: {key}") + new_sd[new_key] = value + + logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") + return new_sd diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 8896c047..1e97c9cd 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -61,6 +61,8 @@ ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. # ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" ARCH_FLUX_1_UNKNOWN = "flux-1" +ARCH_LUMINA_2 = "lumina-2" +ARCH_LUMINA_UNKNOWN = "lumina" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" IMPL_FLUX = "https://github.com/black-forest-labs/flux" +IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -123,6 +126,7 @@ def build_metadata( clip_skip: Optional[int] = None, sd3: Optional[str] = None, flux: Optional[str] = None, + lumina: Optional[str] = None, ): """ sd3: only supports "m", flux: only supports "dev" @@ -146,6 +150,11 @@ def build_metadata( arch = ARCH_FLUX_1_DEV else: arch = ARCH_FLUX_1_UNKNOWN + elif lumina is not None: + if lumina == "lumina2": + arch = ARCH_LUMINA_2 + else: + arch = ARCH_LUMINA_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -167,6 +176,9 @@ def build_metadata( if flux is not None: # Flux impl = IMPL_FLUX + elif lumina is not None: + # Lumina + impl = IMPL_LUMINA elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py new file mode 100644 index 00000000..622c019a --- /dev/null +++ b/library/strategy_lumina.py @@ -0,0 +1,275 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import AutoTokenizer, AutoModel +from library import train_util +from library.strategy_base import ( + LatentsCachingStrategy, + TokenizeStrategy, + TextEncodingStrategy, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +GEMMA_ID = "google/gemma-2-2b" + + +class LuminaTokenizeStrategy(TokenizeStrategy): + def __init__( + self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + GEMMA_ID, cache_dir=tokenizer_cache_dir + ) + self.tokenizer.padding_side = "right" + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + encodings = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + return_tensors="pt", + truncation=True, + ) + return [encodings.input_ids] + + def tokenize_with_weights( + self, text: str | List[str] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # Gemma doesn't support weighted prompts, return uniform weights + tokens = self.tokenize(text) + weights = [torch.ones_like(t) for t in tokens] + return tokens, weights + + +class LuminaTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None: + super().__init__() + self.apply_gemma2_attn_mask = apply_gemma2_attn_mask + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_gemma2_attn_mask: Optional[bool] = None, + ) -> List[torch.Tensor]: + + if apply_gemma2_attn_mask is None: + apply_gemma2_attn_mask = self.apply_gemma2_attn_mask + + text_encoder = models[0] + input_ids = tokens[0].to(text_encoder.device) + + attention_mask = None + position_ids = None + if apply_gemma2_attn_mask: + # Create attention mask (1 for non-padding, 0 for padding) + attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to( + text_encoder.device + ) + + # Create position IDs + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + with torch.no_grad(): + outputs = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + return_dict=True, + ) + # Get the last hidden state + hidden_states = outputs.last_hidden_state + + return [hidden_states] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + # For simplicity, use uniform weighting + return self.encode_tokens(tokenize_strategy, models, tokens_list) + + +class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_gemma2_attn_mask: bool = False, + ) -> None: + super().__init__( + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + is_partial, + ) + self.apply_gemma2_attn_mask = apply_gemma2_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return ( + os.path.splitext(image_abs_path)[0] + + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + ) + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state = data["hidden_state"] + return [hidden_state] + + def cache_batch_outputs( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + infos: List, + ): + lumina_text_encoding_strategy: LuminaTextEncodingStrategy = ( + text_encoding_strategy + ) + captions = [info.caption for info in infos] + + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights( + captions + ) + with torch.no_grad(): + hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + )[0] + else: + tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state = lumina_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens + )[0] + + if hidden_state.dtype == torch.bfloat16: + hidden_state = hidden_state.float() + + hidden_state = hidden_state.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state_i = hidden_state[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state=hidden_state_i, + ) + else: + info.text_encoder_outputs = [hidden_state_i] + + +class LuminaLatentsCachingStrategy(LatentsCachingStrategy): + LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path( + self, absolute_path: str, image_size: Tuple[int, int] + ) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + ): + return self._default_is_disk_cached_latents_expected( + 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True + ) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: + return self._default_load_latents_from_disk( + 8, npz_path, bucket_reso + ) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents( + self, + vae, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + ): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, + vae_device, + vae_dtype, + image_infos, + flip_aug, + alpha_mask, + random_crop, + multi_resolution=True, + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a99..34ffe22b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3463,6 +3463,7 @@ def get_sai_model_spec( is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, flux: str = None, + lumina: str = None, ): timestamp = time.time() @@ -3498,6 +3499,7 @@ def get_sai_model_spec( clip_skip=args.clip_skip, # None or int sd3=sd3, flux=flux, + lumina=lumina, ) return metadata diff --git a/lumina_train_network.py b/lumina_train_network.py new file mode 100644 index 00000000..40b84e14 --- /dev/null +++ b/lumina_train_network.py @@ -0,0 +1,192 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional, Union + +import torch +from accelerate import Accelerator + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import train_network +from library import ( + lumina_models, + flux_train_utils, + lumina_util, + lumina_train_util, + sd3_train_utils, + strategy_base, + strategy_lumina, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LuminaNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group, val_dataset_group): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + + if ( + args.cache_text_encoder_outputs_to_disk + and not args.cache_text_encoder_outputs + ): + logger.warning("Enabling cache_text_encoder_outputs due to disk caching") + args.cache_text_encoder_outputs = True + + train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) + + self.train_gemma2 = not args.network_train_unet_only + + def load_target_model(self, args, weight_dtype, accelerator): + loading_dtype = None if args.fp8 else weight_dtype + + model = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + loading_dtype, + "cpu", + disable_mmap=args.disable_mmap_load_safetensors, + ) + + # if args.blocks_to_swap: + # logger.info(f'Enabling block swap: {args.blocks_to_swap}') + # model.enable_block_swap(args.blocks_to_swap, accelerator.device) + # self.is_swapping_blocks = True + + gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + + return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model + + def get_tokenize_strategy(self, args): + return strategy_lumina.LuminaTokenizeStrategy( + args.gemma2_max_token_length, args.tokenizer_cache_dir + ) + + def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + return strategy_lumina.LuminaLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, False + ) + + def get_text_encoding_strategy(self, args): + return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask) + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_gemma2] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_gemma2, + apply_gemma2_attn_mask=args.apply_gemma2_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, + args, + accelerator: Accelerator, + unet, + vae, + text_encoders, + dataset, + weight_dtype, + ): + for text_encoder in text_encoders: + text_encoder_outputs_caching_strategy = ( + self.get_text_encoder_outputs_caching_strategy(args) + ) + if text_encoder_outputs_caching_strategy is not None: + text_encoder_outputs_caching_strategy.cache_batch_outputs( + self.get_tokenize_strategy(args), + [text_encoder], + self.get_text_encoding_strategy(args), + dataset, + ) + + def sample_images( + self, + accelerator, + args, + epoch, + global_step, + device, + ae, + tokenizer, + text_encoder, + lumina, + ): + lumina_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + lumina, + ae, + self.get_models_for_text_encoding(args, accelerator, text_encoder), + self.sample_prompts_te_outputs, + ) + + # Remaining methods maintain similar structure to flux implementation + # with Lumina-specific model calls and strategies + + def get_noise_scheduler( + self, args: argparse.Namespace, device: torch.device + ) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=args.discrete_flow_shift + ) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + # not sure, they use same flux vae + def shift_scale_latents(self, args, latents): + return latents + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + lumina_train_utils.add_lumina_train_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = LuminaNetworkTrainer() + trainer.train(args) From c0caf33e3fa7a99c2160946e42d4ef7b8d7660a4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 15 Feb 2025 16:38:59 +0800 Subject: [PATCH 02/73] update --- library/lumina_util.py | 8 -- lumina_train_network.py | 175 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 171 insertions(+), 12 deletions(-) diff --git a/library/lumina_util.py b/library/lumina_util.py index 990f8c68..b47e057a 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -108,14 +108,6 @@ def load_gemma2( logger.info(f"Loaded Gemma2: {info}") return gemma2 -def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): - img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] - img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) - return img_ids - - def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: """ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 diff --git a/lumina_train_network.py b/lumina_train_network.py index 40b84e14..db329a9b 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -53,7 +53,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): self.train_gemma2 = not args.network_train_unet_only def load_target_model(self, args, weight_dtype, accelerator): - loading_dtype = None if args.fp8 else weight_dtype + loading_dtype = None if args.fp8_base else weight_dtype model = lumina_util.load_lumina_model( args.pretrained_model_name_or_path, @@ -67,8 +67,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # model.enable_block_swap(args.blocks_to_swap, accelerator.device) # self.is_swapping_blocks = True - gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") - ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + gemma2 = lumina_util.load_gemma2( + args.gemma2, weight_dtype, "cpu" + ) + ae = lumina_util.load_ae( + args.ae, weight_dtype, "cpu" + ) return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model @@ -168,11 +172,174 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def shift_scale_latents(self, args, latents): return latents + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: lumina_models.NextDiT, + network, + weight_dtype, + train_unet, + is_train=True, + ): + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = ( + flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + ) + + # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 + packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) + packed_latent_height, packed_latent_width = ( + noisy_model_input.shape[2] // 2, + noisy_model_input.shape[3] // 2, + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Unpack Gemma2 outputs + gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds + if not args.apply_gemma2_attn_mask: + gemma2_attn_mask = None + + def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask): + with torch.set_grad_enabled(is_train), accelerator.autocast(): + # NextDiT forward expects (x, t, cap_feats, cap_mask) + model_pred = unet( + x=img, # packed latents + t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features + cap_mask=gemma2_attn_mask, # Gemma2的attention mask + ) + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + gemma2_hidden_states=gemma2_hidden_states, + input_ids=input_ids, + timesteps=timesteps, + gemma2_attn_mask=gemma2_attn_mask, + ) + + # unpack latents + model_pred = lumina_util.unpack_latents( + model_pred, packed_latent_height, packed_latent_width + ) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type( + args, model_pred, noisy_model_input, sigmas + ) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if ( + "diff_output_preservation" in custom_attributes + and custom_attributes["diff_output_preservation"] + ): + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + gemma2_hidden_states=gemma2_hidden_states[ + diff_output_pr_indices + ], + input_ids=input_ids[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + gemma2_attn_mask=( + gemma2_attn_mask[diff_output_pr_indices] + if gemma2_attn_mask is not None + else None + ), + ) + network.set_multiplier(1.0) + + model_pred_prior = lumina_util.unpack_latents( + model_pred_prior, packed_latent_height, packed_latent_width + ) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + return train_util.get_sai_model_spec( + None, args, False, True, False, lumina="lumina2" + ) + + def update_metadata(self, metadata, args): + metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + text_encoder.model.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8( + self, index, text_encoder, te_weight_dtype, weight_dtype + ): + logger.info( + f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}" + ) + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.model.embed_tokens.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + nextdit: lumina_models.Nextdit = unet + nextdit = accelerator.prepare( + nextdit, device_placement=[not self.is_swapping_blocks] + ) + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( + accelerator.device + ) # reduce peak memory usage + accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() + + return nextdit def setup_parser() -> argparse.ArgumentParser: From 7323ee1b9dbfd723ee767b7faeee8833421b832d Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 15 Feb 2025 17:10:34 +0800 Subject: [PATCH 03/73] update lora_lumina --- networks/lora_lumina.py | 1011 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 1011 insertions(+) create mode 100644 networks/lora_lumina.py diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py new file mode 100644 index 00000000..d554ce13 --- /dev/null +++ b/networks/lora_lumina.py @@ -0,0 +1,1011 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of lumina as same as Diffusers + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + lumina, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim for JointTransformerBlock + attn_dim = kwargs.get("attn_dim", None) # attention dimension + mlp_dim = kwargs.get("mlp_dim", None) # MLP dimension + mod_dim = kwargs.get("mod_dim", None) # modulation dimension + refiner_dim = kwargs.get("refiner_dim", None) # refiner blocks dimension + + if attn_dim is not None: + attn_dim = int(attn_dim) + if mlp_dim is not None: + mlp_dim = int(mlp_dim) + if mod_dim is not None: + mod_dim = int(mod_dim) + if refiner_dim is not None: + refiner_dim = int(refiner_dim) + + type_dims = [attn_dim, mlp_dim, mod_dim, refiner_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims for embedders + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] + assert len(in_dims) == 4, f"invalid in_dims: {in_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder, final_layer)" + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + lumina, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + type_dims=type_dims, + in_dims=in_dims, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + lumina, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"] + LORA_PREFIX_LUMINA = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder + + def __init__( + self, + text_encoders, # Now this will be a single Gemma2 model + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + + self.type_dims = type_dims + self.in_dims = in_dims + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + + # create module instances + def create_modules( + is_lumina: bool, + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # for handling embedders + module = root_module + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_lumina and type_dims is not None: + identifier = [ + ("attention",), # attention layers + ("mlp",), # MLP layers + ("modulation",), # modulation layers + ("refiner",), # refiner blocks + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder (Gemma2) + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + + logger.info(f"create LoRA for Gemma2 Text Encoder:") + text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) + + # Handle embedders + if self.in_dims: + for filter, in_dim in zip(["x_embedder", "t_embedder", "cap_embedder", "final_layer"], self.in_dims): + loras, _ = create_modules(True, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # # split qkv + # for key in list(state_dict.keys()): + # if "double" in key and "qkv" in key: + # split_dims = [3072] * 3 + # elif "single" in key and "linear1" in key: + # split_dims = [3072] * 3 + [12288] + # else: + # continue + + # weight = state_dict[key] + # lora_name = key.split(".")[0] + + # if key not in state_dict: + # continue # already merged + + # # (rank, in_dim) * 3 + # down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # # (split dim, rank) * 3 + # up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + # alpha = state_dict.pop(f"{lora_name}.alpha") + + # # merge down weight + # down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # # merge up weight (sum of split_dim, rank*3) + # rank = up_weights[0].size(1) + # up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + # i = 0 + # for j in range(len(split_dims)): + # up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + # i += split_dims[j] + + # state_dict[f"{lora_name}.lora_down.weight"] = down_weight + # state_dict[f"{lora_name}.lora_up.weight"] = up_weight + # state_dict[f"{lora_name}.alpha"] = alpha + + # # print( + # # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # # ) + # print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_LUMINA): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te_loras = [lora for lora in self.text_encoder_loras] + if len(te_loras) > 0: + logger.info(f"Text Encoder: {len(te_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) From a00b06bc978c80502850a869c845877aeb451003 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 15 Feb 2025 14:56:11 -0500 Subject: [PATCH 04/73] Lumina 2 and Gemma 2 model loading --- library/lumina_models.py | 35 ++++++++++++-------- library/lumina_util.py | 66 +++++++++++++++++++++++--------------- library/strategy_lumina.py | 2 ++ lumina_train_network.py | 2 +- 4 files changed, 65 insertions(+), 40 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 43b1e9c6..3f2e854e 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -21,7 +21,8 @@ import torch.nn.functional as F try: from apex.normalization import FusedRMSNorm as RMSNorm -except ImportError: +except ModuleNotFoundError: + import warnings warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") memory_efficient_attention = None @@ -39,17 +40,20 @@ except: class LuminaParams: """Parameters for Lumina model configuration""" patch_size: int = 2 - dim: int = 2592 + in_channels: int = 4 + dim: int = 4096 n_layers: int = 30 + n_refiner_layers: int = 2 n_heads: int = 24 n_kv_heads: int = 8 + multiple_of: int = 256 axes_dims: List[int] = None axes_lens: List[int] = None - qk_norm: bool = False, - ffn_dim_multiplier: Optional[float] = None, - norm_eps: float = 1e-5, - scaling_factor: float = 1.0, - cap_feat_dim: int = 32, + qk_norm: bool = False + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + scaling_factor: float = 1.0 + cap_feat_dim: int = 32 def __post_init__(self): if self.axes_dims is None: @@ -62,12 +66,15 @@ class LuminaParams: """Returns the configuration for the 2B parameter model""" return cls( patch_size=2, - dim=2592, - n_layers=30, + in_channels=16, + dim=2304, + n_layers=26, n_heads=24, n_kv_heads=8, - axes_dims=[36, 36, 36], - axes_lens=[300, 512, 512] + axes_dims=[32, 32, 32], + axes_lens=[300, 512, 512], + qk_norm=True, + cap_feat_dim=2304 ) @classmethod @@ -696,8 +703,8 @@ class NextDiT(nn.Module): norm_eps: float = 1e-5, qk_norm: bool = False, cap_feat_dim: int = 5120, - axes_dims: List[int] = (16, 56, 56), - axes_lens: List[int] = (1, 512, 512), + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], ) -> None: super().__init__() self.in_channels = in_channels @@ -1090,6 +1097,7 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, * return NextDiT( patch_size=params.patch_size, + in_channels=params.in_channels, dim=params.dim, n_layers=params.n_layers, n_heads=params.n_heads, @@ -1099,7 +1107,6 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, * qk_norm=params.qk_norm, ffn_dim_multiplier=params.ffn_dim_multiplier, norm_eps=params.norm_eps, - scaling_factor=params.scaling_factor, cap_feat_dim=params.cap_feat_dim, **kwargs, ) diff --git a/library/lumina_util.py b/library/lumina_util.py index b47e057a..f8e3f7db 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -27,14 +27,14 @@ def load_lumina_model( dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False, -) -> lumina_models.Lumina: +): logger.info("Building Lumina") with torch.device("meta"): model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) logger.info(f"Loading state dict from {ckpt_path}") state_dict = load_safetensors( - ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype + ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype ) info = model.load_state_dict(state_dict, strict=False, assign=True) logger.info(f"Loaded Lumina: {info}") @@ -69,30 +69,39 @@ def load_gemma2( ) -> Gemma2Model: logger.info("Building Gemma2") GEMMA2_CONFIG = { - "_name_or_path": "google/gemma-2b", - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 2, - "eos_token_id": 1, - "head_dim": 256, - "hidden_act": "gelu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 16384, - "max_position_embeddings": 8192, - "model_type": "gemma", - "num_attention_heads": 8, - "num_hidden_layers": 18, - "num_key_value_heads": 1, - "pad_token_id": 0, - "rms_norm_eps": 1e-06, - "rope_scaling": null, - "rope_theta": 10000.0, - "torch_dtype": "bfloat16", - "transformers_version": "4.38.0.dev0", - "use_cache": true, - "vocab_size": 256000 + "_name_or_path": "google/gemma-2-2b", + "architectures": [ + "Gemma2Model" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": 50.0, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": 1, + "final_logit_softcapping": 30.0, + "head_dim": 256, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2304, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 8192, + "model_type": "gemma2", + "num_attention_heads": 8, + "num_hidden_layers": 26, + "num_key_value_heads": 4, + "pad_token_id": 0, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "sliding_window": 4096, + "torch_dtype": "float32", + "transformers_version": "4.44.2", + "use_cache": True, + "vocab_size": 256000 } + config = Gemma2Config(**GEMMA2_CONFIG) with init_empty_weights(): gemma2 = Gemma2Model._from_config(config) @@ -104,6 +113,13 @@ def load_gemma2( sd = load_safetensors( ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype ) + + for key in list(sd.keys()): + new_key = key.replace("model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + info = gemma2.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Gemma2: {info}") return gemma2 diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 622c019a..615f6e00 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -9,7 +9,9 @@ from library.strategy_base import ( LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy, + TextEncoderOutputsCachingStrategy ) +import numpy as np from library.utils import setup_logging setup_logging() diff --git a/lumina_train_network.py b/lumina_train_network.py index db329a9b..1f8ba613 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -345,7 +345,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser) - lumina_train_utils.add_lumina_train_arguments(parser) + lumina_train_util.add_lumina_train_arguments(parser) return parser From 60a76ebb72772327fcb7b2a10c87ad8f7b09f56f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:06:34 -0500 Subject: [PATCH 05/73] Add caching gemma2, add gradient checkpointing, refactor lumina model code --- library/lumina_models.py | 298 +++++++++++++++++++------------------ library/strategy_lumina.py | 108 ++++++++------ lumina_train_network.py | 113 ++++++++++---- networks/lora_lumina.py | 10 +- 4 files changed, 304 insertions(+), 225 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 3f2e854e..27194e2f 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -16,6 +16,8 @@ from dataclasses import dataclass from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F @@ -91,6 +93,25 @@ class LuminaParams: ) +class GradientCheckpointMixin(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = False + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + ############################################################################# # RMSNorm # ############################################################################# @@ -114,7 +135,7 @@ class RMSNorm(torch.nn.Module): self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) - def _norm(self, x): + def _norm(self, x) -> Tensor: """ Apply the RMSNorm normalization to the input tensor. @@ -125,21 +146,14 @@ class RMSNorm(torch.nn.Module): torch.Tensor: The normalized tensor. """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - - """ - output = self._norm(x.float()).type_as(x) - return output * self.weight + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) def modulate(x, scale): @@ -151,7 +165,7 @@ def modulate(x, scale): ############################################################################# -class TimestepEmbedder(nn.Module): +class TimestepEmbedder(GradientCheckpointMixin): """ Embeds scalar timesteps into vector representations. """ @@ -203,11 +217,32 @@ class TimestepEmbedder(nn.Module): ) return embedding - def forward(self, t): + def _forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) return t_emb +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + ############################################################################# # Core NextDiT Model # @@ -284,7 +319,7 @@ class JointAttention(nn.Module): Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ - with torch.amp.autocast("cuda",enabled=False): + with torch.autocast("cuda", enabled=False): x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x * freqs_cis).flatten(3) @@ -496,15 +531,15 @@ class FeedForward(nn.Module): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) -class JointTransformerBlock(nn.Module): +class JointTransformerBlock(GradientCheckpointMixin): def __init__( self, layer_id: int, dim: int, n_heads: int, - n_kv_heads: int, + n_kv_heads: Optional[int], multiple_of: int, - ffn_dim_multiplier: float, + ffn_dim_multiplier: Optional[float], norm_eps: float, qk_norm: bool, modulation=True, @@ -520,7 +555,7 @@ class JointTransformerBlock(nn.Module): value features (if using GQA), or set to None for the same as query. multiple_of (int): - ffn_dim_multiplier (float): + ffn_dim_multiplier (Optional[float]): norm_eps (float): """ @@ -554,7 +589,7 @@ class JointTransformerBlock(nn.Module): nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) - def forward( + def _forward( self, x: torch.Tensor, x_mask: torch.Tensor, @@ -608,7 +643,7 @@ class JointTransformerBlock(nn.Module): return x -class FinalLayer(nn.Module): +class FinalLayer(GradientCheckpointMixin): """ The final layer of NextDiT. """ @@ -661,22 +696,21 @@ class RopeEmbedder: self.axes_dims, self.axes_lens, theta=self.theta ) - def __call__(self, ids: torch.Tensor): + def get_freqs_cis(self, ids: torch.Tensor): self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] result = [] for i in range(len(self.axes_dims)): - # import torch.distributed as dist - # if not dist.is_initialized() or dist.get_rank() == 0: - # import pdb - # pdb.set_trace() index = ( ids[:, :, i : i + 1] .repeat(1, 1, self.freqs_cis[i].shape[-1]) .to(torch.int64) ) + + axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1) + result.append( torch.gather( - self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), + axes, dim=1, index=index, ) @@ -790,76 +824,98 @@ class NextDiT(nn.Module): self.dim = dim self.n_heads = n_heads + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.t_embedder.enable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + self.final_layer.enable_gradient_checkpointing() + + print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.t_embedder.disable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.disable_gradient_checkpointing() + + self.final_layer.disable_gradient_checkpointing() + + print("Lumina: Gradient checkpointing disabled.") + def unpatchify( self, x: torch.Tensor, - img_size: List[Tuple[int, int]], - cap_size: List[int], - return_tensor=False, - ) -> List[torch.Tensor]: + width: int, + height: int, + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> torch.Tensor: """ x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) """ pH = pW = self.patch_size - imgs = [] - for i in range(x.size(0)): - H, W = img_size[i] - begin = cap_size[i] - end = begin + (H // pH) * (W // pW) - imgs.append( - x[i][begin:end] - .view(H // pH, W // pW, pH, pW, self.out_channels) + + output = [] + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + output.append( + x[i][encoder_seq_len:seq_len] + .view(height // pH, width // pW, pH, pW, self.out_channels) .permute(4, 0, 2, 1, 3) .flatten(3, 4) .flatten(1, 2) ) + output = torch.stack(output, dim=0) - if return_tensor: - imgs = torch.stack(imgs, dim=0) - return imgs + return output def patchify_and_embed( self, - x: List[torch.Tensor] | torch.Tensor, + x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, ) -> Tuple[ - torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor + torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int] ]: - bsz = len(x) + bsz, channels, height, width = x.shape pH = pW = self.patch_size - device = x[0].device + device = x.device l_effective_cap_len = cap_mask.sum(dim=1).tolist() - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + encoder_seq_len = cap_mask.shape[1] - max_seq_len = max( - ( - cap_len + img_len - for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len) - ) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + image_seq_len = (height // self.patch_size) * (width // self.patch_size) + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] + max_seq_len = max(seq_lengths) - position_ids = torch.zeros( - bsz, max_seq_len, 3, dtype=torch.int32, device=device - ) + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + H_tokens, W_tokens = height // pH, width // pW - position_ids[i, :cap_len, 0] = torch.arange( - cap_len, dtype=torch.int32, device=device - ) - position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len row_ids = ( torch.arange(H_tokens, dtype=torch.int32, device=device) .view(-1, 1) @@ -872,77 +928,40 @@ class NextDiT(nn.Module): .repeat(H_tokens, 1) .flatten() ) - position_ids[i, cap_len : cap_len + img_len, 1] = row_ids - position_ids[i, cap_len : cap_len + img_len, 2] = col_ids + position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids + position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids - freqs_cis = self.rope_embedder(position_ids) + freqs_cis = self.rope_embedder.get_freqs_cis(position_ids) - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros( - *cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype - ) + cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) + img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros( - *img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype - ) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len] + + x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) # refine context for layer in self.context_refiner: cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = ( - img.view(C, H // pH, pH, W // pW, pW) - .permute(1, 3, 2, 4, 0) - .flatten(2) - .flatten(0, 1) - ) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros( - bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype - ) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) - for i in range(bsz): - padded_img_embed[i, : l_effective_img_len[i]] = x[i] - padded_img_mask[i, : l_effective_img_len[i]] = True + x = self.x_embedder(x) - padded_img_embed = self.x_embedder(padded_img_embed) for layer in self.noise_refiner: - padded_img_embed = layer( - padded_img_embed, padded_img_mask, img_freqs_cis, t - ) + x = layer(x, x_mask, img_freqs_cis, t) - mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - padded_full_embed = torch.zeros( - bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype - ) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] + joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype) + attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len] + joint_hidden_states[i, cap_len:seq_len] = x[i] - mask[i, : cap_len + img_len] = True - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[ - i, :img_len - ] + x = joint_hidden_states - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths def forward(self, x, t, cap_feats, cap_mask): """ @@ -950,30 +969,19 @@ class NextDiT(nn.Module): t: (N,) tensor of diffusion timesteps y: (N,) tensor of text tokens/features """ - - # import torch.distributed as dist - # if not dist.is_initialized() or dist.get_rank() == 0: - # import pdb - # pdb.set_trace() - # torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt") + _, _, height, width = x.shape # B, C, H, W t = self.t_embedder(t) # (N, D) - adaln_input = t + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - cap_feats = self.cap_embedder( - cap_feats - ) # (N, L, D) # todo check if able to batchify w.o. redundant compute - - x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( + x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed( x, cap_feats, cap_mask, t ) - freqs_cis = freqs_cis.to(x.device) for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + x = layer(x, mask, freqs_cis, t) - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor) + x = self.final_layer(x, t) + x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths) return x diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 615f6e00..6feea387 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -3,7 +3,7 @@ import os from typing import Any, List, Optional, Tuple, Union import torch -from transformers import AutoTokenizer, AutoModel +from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast from library import train_util from library.strategy_base import ( LatentsCachingStrategy, @@ -27,34 +27,35 @@ class LuminaTokenizeStrategy(TokenizeStrategy): def __init__( self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None ) -> None: - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained( GEMMA_ID, cache_dir=tokenizer_cache_dir ) self.tokenizer.padding_side = "right" if max_length is None: - self.max_length = self.tokenizer.model_max_length + self.max_length = 256 else: self.max_length = max_length - def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]: text = [text] if isinstance(text, str) else text encodings = self.tokenizer( text, - padding="max_length", max_length=self.max_length, return_tensors="pt", + padding=True, + pad_to_multiple_of=8, truncation=True, ) - return [encodings.input_ids] + return encodings.input_ids, encodings.attention_mask def tokenize_with_weights( self, text: str | List[str] - ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: # Gemma doesn't support weighted prompts, return uniform weights - tokens = self.tokenize(text) + tokens, attention_masks = self.tokenize(text) weights = [torch.ones_like(t) for t in tokens] - return tokens, weights + return tokens, attention_masks, weights class LuminaTextEncodingStrategy(TextEncodingStrategy): @@ -66,50 +67,39 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: List[torch.Tensor], + tokens: torch.Tensor, + attention_masks: torch.Tensor, apply_gemma2_attn_mask: Optional[bool] = None, - ) -> List[torch.Tensor]: - + ) -> torch.Tensor: if apply_gemma2_attn_mask is None: apply_gemma2_attn_mask = self.apply_gemma2_attn_mask text_encoder = models[0] - input_ids = tokens[0].to(text_encoder.device) - attention_mask = None - position_ids = None - if apply_gemma2_attn_mask: - # Create attention mask (1 for non-padding, 0 for padding) - attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to( - text_encoder.device - ) + # Create position IDs + position_ids = attention_masks.cumsum(-1) - 1 + position_ids.masked_fill_(attention_masks == 0, 1) - # Create position IDs - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + outputs = text_encoder( + input_ids=tokens.to(text_encoder.device), + attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None, + position_ids=position_ids.to(text_encoder.device), + output_hidden_states=True, + return_dict=True, + ) - with torch.no_grad(): - outputs = text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_hidden_states=True, - return_dict=True, - ) - # Get the last hidden state - hidden_states = outputs.last_hidden_state - - return [hidden_states] + return outputs.hidden_states[-2] def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens_list: List[torch.Tensor], + tokens: torch.Tensor, weights_list: List[torch.Tensor], - ) -> List[torch.Tensor]: + attention_masks: torch.Tensor + ) -> torch.Tensor: # For simplicity, use uniform weighting - return self.encode_tokens(tokenize_strategy, models, tokens_list) + return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks) class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -149,6 +139,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) npz = np.load(npz_path) if "hidden_state" not in npz: return False + if "attention_mask" not in npz: + return False + if "input_ids" not in npz: + return False + if "apply_gemma2_attn_mask" not in npz: + return False + npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"] + if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -158,13 +157,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) hidden_state = data["hidden_state"] - return [hidden_state] + attention_mask = data["attention_mask"] + input_ids = data["input_ids"] + return [hidden_state, attention_mask, input_ids] def cache_batch_outputs( self, - tokenize_strategy: TokenizeStrategy, + tokenize_strategy: LuminaTokenizeStrategy, models: List[Any], - text_encoding_strategy: TextEncodingStrategy, + text_encoding_strategy: LuminaTextEncodingStrategy, infos: List, ): lumina_text_encoding_strategy: LuminaTextEncodingStrategy = ( @@ -173,35 +174,44 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) captions = [info.caption for info in infos] if self.is_weighted: - tokens_list, weights_list = tokenize_strategy.tokenize_with_weights( + tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights( captions ) with torch.no_grad(): hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, models, tokens_list, weights_list - )[0] + tokenize_strategy, models, tokens, weights_list, attention_masks + ) else: - tokens = tokenize_strategy.tokenize(captions) + tokens, attention_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): hidden_state = lumina_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens - )[0] + tokenize_strategy, models, tokens, attention_masks + ) - if hidden_state.dtype == torch.bfloat16: + if hidden_state.dtype != torch.float32: hidden_state = hidden_state.float() hidden_state = hidden_state.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() + input_ids = tokens.cpu().numpy() + for i, info in enumerate(infos): hidden_state_i = hidden_state[i] + attention_mask_i = attention_mask[i] + input_ids_i = input_ids[i] + apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask if self.cache_to_disk: np.savez( info.text_encoder_outputs_npz, hidden_state=hidden_state_i, + attention_mask=attention_mask_i, + input_ids=input_ids_i, + apply_gemma2_attn_mask=apply_gemma2_attn_mask_i, ) else: - info.text_encoder_outputs = [hidden_state_i] + info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] class LuminaLatentsCachingStrategy(LatentsCachingStrategy): diff --git a/lumina_train_network.py b/lumina_train_network.py index 1f8ba613..3d0c7062 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -62,6 +62,19 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): disable_mmap=args.disable_mmap_load_safetensors, ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 Lumina 2 model") + else: + logger.info( + "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + # if args.blocks_to_swap: # logger.info(f'Enabling block swap: {args.blocks_to_swap}') # model.enable_block_swap(args.blocks_to_swap, accelerator.device) @@ -70,6 +83,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): gemma2 = lumina_util.load_gemma2( args.gemma2, weight_dtype, "cpu" ) + gemma2.eval() ae = lumina_util.load_ae( args.ae, weight_dtype, "cpu" ) @@ -118,17 +132,65 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): dataset, weight_dtype, ): - for text_encoder in text_encoders: - text_encoder_outputs_caching_strategy = ( - self.get_text_encoder_outputs_caching_strategy(args) - ) - if text_encoder_outputs_caching_strategy is not None: - text_encoder_outputs_caching_strategy.cache_batch_outputs( - self.get_tokenize_strategy(args), - [text_encoder], - self.get_text_encoding_strategy(args), - dataset, - ) + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + + if text_encoders[0].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[0].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move Gemma 2 back to cpu") + text_encoders[0].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) def sample_images( self, @@ -196,12 +258,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) ) + # May not need to pack/unpack? # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 - packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) - packed_latent_height, packed_latent_width = ( - noisy_model_input.shape[2] // 2, - noisy_model_input.shape[3] // 2, - ) + # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) + # packed_latent_height, packed_latent_width = ( + # noisy_model_input.shape[2] // 2, + # noisy_model_input.shape[3] // 2, + # ) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -212,32 +275,30 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # Unpack Gemma2 outputs gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds - if not args.apply_gemma2_attn_mask: - gemma2_attn_mask = None - def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask): + def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = unet( - x=img, # packed latents + x=img, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features - cap_mask=gemma2_attn_mask, # Gemma2的attention mask + cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask ) return model_pred model_pred = call_dit( - img=packed_noisy_model_input, + img=noisy_model_input, gemma2_hidden_states=gemma2_hidden_states, - input_ids=input_ids, timesteps=timesteps, gemma2_attn_mask=gemma2_attn_mask, ) + # May not need to pack/unpack? # unpack latents - model_pred = lumina_util.unpack_latents( - model_pred, packed_latent_height, packed_latent_width - ) + # model_pred = lumina_util.unpack_latents( + # model_pred, packed_latent_height, packed_latent_width + # ) # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type( diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index d554ce13..3f6c9b41 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei class LoRANetwork(torch.nn.Module): LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"] LORA_PREFIX_LUMINA = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder @@ -533,7 +533,7 @@ class LoRANetwork(torch.nn.Module): filter: Optional[str] = None, default_dim: Optional[int] = None, ) -> List[LoRAModule]: - prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER + prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER loras = [] skipped = [] @@ -611,7 +611,7 @@ class LoRANetwork(torch.nn.Module): skipped_te = [] logger.info(f"create LoRA for Gemma2 Text Encoder:") - text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped @@ -718,10 +718,10 @@ class LoRANetwork(torch.nn.Module): def state_dict(self, destination=None, prefix="", keep_vars=False): if not self.split_qkv: - return super().state_dict(destination, prefix, keep_vars) + return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # merge qkv - state_dict = super().state_dict(destination, prefix, keep_vars) + state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) new_state_dict = {} for key in list(state_dict.keys()): if "double" in key and "qkv" in key: From 16015635d24cad3d8e2149907c24715ea0a37d4f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:36:29 -0500 Subject: [PATCH 06/73] Update metadata.resolution for Lumina 2 --- library/sai_model_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 1e97c9cd..f5343924 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -237,7 +237,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl or sd3 is not None or flux is not None: + if sdxl or sd3 is not None or flux is not None or lumina is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 From 733fdc09c63eb2830081c6b531bf1115075c0f7b Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 17 Feb 2025 14:52:48 +0800 Subject: [PATCH 07/73] update --- library/lumina_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 43b1e9c6..4daa6342 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -21,7 +21,7 @@ import torch.nn.functional as F try: from apex.normalization import FusedRMSNorm as RMSNorm -except ImportError: +except: warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") memory_efficient_attention = None From aa36c48685bad4bcc0fc341fdd516f0ee5c2cf01 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 17 Feb 2025 19:00:18 +0800 Subject: [PATCH 08/73] update for always use gemma2 mask --- library/lumina_train_util.py | 7 +-- library/strategy_lumina.py | 52 +++++++++------------- lumina_train_network.py | 83 +++++++++++++++++++++--------------- 3 files changed, 68 insertions(+), 74 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index d3edd262..7ade6c1b 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -227,7 +227,7 @@ def sample_image_inference( ) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) - gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None + gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) # if controlnet_image is not None: # controlnet_image = Image.open(controlnet_image).convert("RGB") @@ -511,11 +511,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev" " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", ) - parser.add_argument( - "--apply_gemma2_attn_mask", - action="store_true", - help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する", - ) parser.add_argument( "--guidance_scale", diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 6feea387..209f62a0 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -47,7 +47,7 @@ class LuminaTokenizeStrategy(TokenizeStrategy): pad_to_multiple_of=8, truncation=True, ) - return encodings.input_ids, encodings.attention_mask + return [encodings.input_ids, encodings.attention_mask] def tokenize_with_weights( self, text: str | List[str] @@ -59,47 +59,36 @@ class LuminaTokenizeStrategy(TokenizeStrategy): class LuminaTextEncodingStrategy(TextEncodingStrategy): - def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None: + def __init__(self) -> None: super().__init__() - self.apply_gemma2_attn_mask = apply_gemma2_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: torch.Tensor, - attention_masks: torch.Tensor, - apply_gemma2_attn_mask: Optional[bool] = None, - ) -> torch.Tensor: - if apply_gemma2_attn_mask is None: - apply_gemma2_attn_mask = self.apply_gemma2_attn_mask - + tokens: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: text_encoder = models[0] - - # Create position IDs - position_ids = attention_masks.cumsum(-1) - 1 - position_ids.masked_fill_(attention_masks == 0, 1) + input_ids, attention_masks = tokens outputs = text_encoder( - input_ids=tokens.to(text_encoder.device), - attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None, - position_ids=position_ids.to(text_encoder.device), + input_ids=input_ids.to(text_encoder.device), + attention_mask=attention_masks.to(text_encoder.device), output_hidden_states=True, return_dict=True, ) - return outputs.hidden_states[-2] + return outputs.hidden_states[-2], input_ids, attention_masks def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: torch.Tensor, + tokens: List[torch.Tensor], weights_list: List[torch.Tensor], - attention_masks: torch.Tensor - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # For simplicity, use uniform weighting - return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks) + return self.encode_tokens(tokenize_strategy, models, tokens) class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -111,7 +100,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False, - apply_gemma2_attn_mask: bool = False, ) -> None: super().__init__( cache_to_disk, @@ -119,7 +107,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) skip_disk_cache_validity_check, is_partial, ) - self.apply_gemma2_attn_mask = apply_gemma2_attn_mask def get_outputs_npz_path(self, image_abs_path: str) -> str: return ( @@ -146,7 +133,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) if "apply_gemma2_attn_mask" not in npz: return False npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"] - if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask: + if not npz_apply_gemma2_attn_mask: return False except Exception as e: logger.error(f"Error loading file: {npz_path}") @@ -174,18 +161,18 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) captions = [info.caption for info in infos] if self.is_weighted: - tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights( + tokens, weights_list = tokenize_strategy.tokenize_with_weights( captions ) with torch.no_grad(): - hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, models, tokens, weights_list, attention_masks + hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens, weights_list ) else: - tokens, attention_masks = tokenize_strategy.tokenize(captions) + tokens = tokenize_strategy.tokenize(captions) with torch.no_grad(): - hidden_state = lumina_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens, attention_masks + hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens ) if hidden_state.dtype != torch.float32: @@ -200,7 +187,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) hidden_state_i = hidden_state[i] attention_mask_i = attention_mask[i] input_ids_i = input_ids[i] - apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask if self.cache_to_disk: np.savez( @@ -208,7 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) hidden_state=hidden_state_i, attention_mask=attention_mask_i, input_ids=input_ids_i, - apply_gemma2_attn_mask=apply_gemma2_attn_mask_i, + apply_gemma2_attn_mask=True ) else: info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] diff --git a/lumina_train_network.py b/lumina_train_network.py index 3d0c7062..00c81bce 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -64,7 +64,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): if args.fp8_base: # check dtype of model - if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + if ( + model.dtype == torch.float8_e4m3fnuz + or model.dtype == torch.float8_e5m2 + or model.dtype == torch.float8_e5m2fnuz + ): raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 Lumina 2 model") @@ -80,13 +84,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # model.enable_block_swap(args.blocks_to_swap, accelerator.device) # self.is_swapping_blocks = True - gemma2 = lumina_util.load_gemma2( - args.gemma2, weight_dtype, "cpu" - ) + gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") gemma2.eval() - ae = lumina_util.load_ae( - args.ae, weight_dtype, "cpu" - ) + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model @@ -104,7 +104,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) def get_text_encoding_strategy(self, args): - return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask) + return strategy_lumina.LuminaTextEncodingStrategy() def get_text_encoders_train_flags(self, args, text_encoders): return [self.train_gemma2] @@ -117,7 +117,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): args.text_encoder_batch_size, args.skip_cache_check, is_partial=self.train_gemma2, - apply_gemma2_attn_mask=args.apply_gemma2_attn_mask, ) else: return None @@ -144,11 +143,15 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[0].to( + accelerator.device, dtype=weight_dtype + ) # always not fp8 if text_encoders[0].dtype == torch.float8_e4m3fn: # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + self.prepare_text_encoder_fp8( + 1, text_encoders[1], text_encoders[1].dtype, weight_dtype + ) else: # otherwise, we need to convert it to target dtype text_encoders[0].to(weight_dtype) @@ -158,21 +161,39 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # cache sample prompts if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + logger.info( + f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" + ) - tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = ( + strategy_base.TokenizeStrategy.get_strategy() + ) + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( + strategy_base.TextEncodingStrategy.get_strategy() + ) prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + sample_prompts_te_outputs = ( + {} + ) # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + for p in [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ]: if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") + logger.info( + f"cache Text Encoder outputs for prompt: {p}" + ) tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + sample_prompts_te_outputs[p] = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_t5_attn_mask, + ) ) self.sample_prompts_te_outputs = sample_prompts_te_outputs @@ -261,10 +282,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # May not need to pack/unpack? # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) - # packed_latent_height, packed_latent_width = ( - # noisy_model_input.shape[2] // 2, - # noisy_model_input.shape[3] // 2, - # ) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -274,16 +291,18 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) # Unpack Gemma2 outputs - gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds + gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = unet( - x=img, # image latents (B, C, H, W) + x=img, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features - cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask + cap_mask=gemma2_attn_mask.to( + dtype=torch.int32 + ), # Gemma2的attention mask ) return model_pred @@ -326,13 +345,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): gemma2_hidden_states=gemma2_hidden_states[ diff_output_pr_indices ], - input_ids=input_ids[diff_output_pr_indices], timesteps=timesteps[diff_output_pr_indices], - gemma2_attn_mask=( - gemma2_attn_mask[diff_output_pr_indices] - if gemma2_attn_mask is not None - else None - ), + gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]), ) network.set_multiplier(1.0) @@ -358,7 +372,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) def update_metadata(self, metadata, args): - metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_logit_mean"] = args.logit_mean metadata["ss_logit_std"] = args.logit_std @@ -373,7 +386,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): - text_encoder.model.embed_tokens.requires_grad_(True) + text_encoder.embed_tokens.requires_grad_(True) def prepare_text_encoder_fp8( self, index, text_encoder, te_weight_dtype, weight_dtype @@ -382,7 +395,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}" ) text_encoder.to(te_weight_dtype) # fp8 - text_encoder.model.embed_tokens.to(dtype=weight_dtype) + text_encoder.embed_tokens.to(dtype=weight_dtype) def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module From 44782dd7905d56fedfcb4cf8e51d162d2f2d3e23 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 11:14:38 -0500 Subject: [PATCH 09/73] Fix validation epoch divergence --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c3879531..b5f92e06 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 3365cfadd7af64c6468210f98801396ffeb4873f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:18:14 -0500 Subject: [PATCH 10/73] Fix sizes for validation split --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 34ffe22b..f9fe317f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -148,10 +148,11 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" def split_train_val( paths: List[str], + sizes: List[Optional[Tuple[int, int]]], is_training_dataset: bool, validation_split: float, validation_seed: int | None -) -> List[str]: +) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -172,10 +173,12 @@ def split_train_val( # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part - return paths[0:math.ceil(len(paths) * (1 - validation_split))] + split = math.ceil(len(paths) * (1 - validation_split)) + return paths[0:split], sizes[0:split] else: # Validation dataset we split to the second part - return paths[len(paths) - round(len(paths) * validation_split):] + split = len(paths) - round(len(paths) * validation_split) + return paths[split:], sizes[split:] class ImageInfo: @@ -1931,12 +1934,12 @@ class DreamBoothDataset(BaseDataset): with open(info_cache_file, "r", encoding="utf-8") as f: metas = json.load(f) img_paths = list(metas.keys()) - sizes = [meta["resolution"] for meta in metas.values()] + sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - sizes = [None] * len(img_paths) + sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths) # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() @@ -1969,7 +1972,7 @@ class DreamBoothDataset(BaseDataset): w, h = None, None if w is not None and h is not None: - sizes[i] = [w, h] + sizes[i] = (w, h) size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") @@ -1990,8 +1993,9 @@ class DreamBoothDataset(BaseDataset): # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: - img_paths = split_train_val( + img_paths, sizes = split_train_val( img_paths, + sizes, self.is_training_dataset, self.validation_split, self.validation_seed From 3ed7606f8840c166c3d7b8e6daa170070c749b0b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:28:34 -0500 Subject: [PATCH 11/73] Clear sizes for validation reg images to be consistent --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index f9fe317f..4eccc4a0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1990,6 +1990,7 @@ class DreamBoothDataset(BaseDataset): # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] + sizes = [] # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: From 1aa2f00e85cf7802007a394e28d52014c776df48 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:42:44 -0500 Subject: [PATCH 12/73] Fix validation epoch loss to check epoch average --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index b5f92e06..674f1cb6 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 98efbc3bb784d9246a70575349b309faef9e2ecf Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Feb 2025 00:58:53 -0500 Subject: [PATCH 13/73] Add documentation to model, use SDPA attention, sample images --- library/lumina_models.py | 419 +++++++++++++++++++--------------- library/lumina_train_util.py | 421 ++++++++++++++++++++++++++--------- library/lumina_util.py | 49 +++- library/strategy_lumina.py | 16 +- lumina_train_network.py | 71 +++--- 5 files changed, 643 insertions(+), 333 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 27194e2f..e82f3b2c 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -13,6 +13,7 @@ import math from typing import List, Optional, Tuple from dataclasses import dataclass +from einops import rearrange from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa import torch @@ -23,24 +24,16 @@ import torch.nn.functional as F try: from apex.normalization import FusedRMSNorm as RMSNorm -except ModuleNotFoundError: +except: import warnings + warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") -memory_efficient_attention = None -try: - import xformers -except: - pass - -try: - from xformers.ops import memory_efficient_attention -except: - memory_efficient_attention = None @dataclass class LuminaParams: """Parameters for Lumina model configuration""" + patch_size: int = 2 in_channels: int = 4 dim: int = 4096 @@ -68,7 +61,7 @@ class LuminaParams: """Returns the configuration for the 2B parameter model""" return cls( patch_size=2, - in_channels=16, + in_channels=16, # VAE channels dim=2304, n_layers=26, n_heads=24, @@ -76,21 +69,13 @@ class LuminaParams: axes_dims=[32, 32, 32], axes_lens=[300, 512, 512], qk_norm=True, - cap_feat_dim=2304 + cap_feat_dim=2304, # Gemma 2 hidden_size ) @classmethod def get_7b_config(cls) -> "LuminaParams": """Returns the configuration for the 7B parameter model""" - return cls( - patch_size=2, - dim=4096, - n_layers=32, - n_heads=32, - n_kv_heads=8, - axes_dims=[64, 64, 64], - axes_lens=[300, 512, 512] - ) + return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512]) class GradientCheckpointMixin(nn.Module): @@ -112,6 +97,7 @@ class GradientCheckpointMixin(nn.Module): else: return self._forward(*args, **kwargs) + ############################################################################# # RMSNorm # ############################################################################# @@ -148,9 +134,18 @@ class RMSNorm(torch.nn.Module): """ return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x: Tensor): + """ + Apply RMSNorm to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + """ x_dtype = x.dtype + # To handle float8 we need to convert the tensor to float x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) @@ -204,17 +199,11 @@ class TimestepEmbedder(GradientCheckpointMixin): """ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=t.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def _forward(self, t): @@ -222,6 +211,7 @@ class TimestepEmbedder(GradientCheckpointMixin): t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) return t_emb + def to_cuda(x): if isinstance(x, torch.Tensor): return x.cuda() @@ -266,6 +256,7 @@ class JointAttention(nn.Module): dim (int): Number of input dimensions. n_heads (int): Number of heads. n_kv_heads (Optional[int]): Number of kv heads, if using GQA. + qk_norm (bool): Whether to use normalization for queries and keys. """ super().__init__() @@ -295,6 +286,14 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() + self.flash_attn = False + + # self.attention_processor = xformers.ops.memory_efficient_attention + self.attention_processor = F.scaled_dot_product_attention + + def set_attention_processor(self, attention_processor): + self.attention_processor = attention_processor + @staticmethod def apply_rotary_emb( x_in: torch.Tensor, @@ -326,16 +325,12 @@ class JointAttention(nn.Module): return x_out.type_as(x_in) # copied from huggingface modeling_llama.py - def _upad_input( - self, query_layer, key_layer, value_layer, attention_mask, query_length - ): + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) - ) + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -355,9 +350,7 @@ class JointAttention(nn.Module): ) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape( - batch_size * kv_seq_len, self.n_local_heads, head_dim - ), + query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k @@ -373,9 +366,7 @@ class JointAttention(nn.Module): else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask - ) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -388,10 +379,10 @@ class JointAttention(nn.Module): def forward( self, - x: torch.Tensor, - x_mask: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: + x: Tensor, + x_mask: Tensor, + freqs_cis: Tensor, + ) -> Tensor: """ Args: @@ -425,7 +416,7 @@ class JointAttention(nn.Module): softmax_scale = math.sqrt(1 / self.head_dim) - if dtype in [torch.float16, torch.bfloat16]: + if self.flash_attn: # begin var_len flash attn ( query_states, @@ -459,14 +450,13 @@ class JointAttention(nn.Module): if n_rep >= 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + output = ( - F.scaled_dot_product_attention( + self.attention_processor( xq.permute(0, 2, 1, 3), xk.permute(0, 2, 1, 3), xv.permute(0, 2, 1, 3), - attn_mask=x_mask.bool() - .view(bsz, 1, 1, seqlen) - .expand(-1, self.n_local_heads, seqlen, -1), + attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), scale=softmax_scale, ) .permute(0, 2, 1, 3) @@ -474,10 +464,47 @@ class JointAttention(nn.Module): ) output = output.flatten(-2) - return self.out(output) +def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def apply_rope( + x_in: torch.Tensor, + freqs_cis: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency + tensor. + + This function applies rotary embeddings to the given query 'xq' and + key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The + input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors + contain rotary embeddings and are returned as real tensors. + + Args: + x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex + exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor + and key tensor with rotary embeddings. + """ + with torch.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + + return x_out.type_as(x_in) + + class FeedForward(nn.Module): def __init__( self, @@ -554,10 +581,13 @@ class JointTransformerBlock(GradientCheckpointMixin): n_kv_heads (Optional[int]): Number of attention heads in key and value features (if using GQA), or set to None for the same as query. - multiple_of (int): - ffn_dim_multiplier (Optional[float]): - norm_eps (float): - + multiple_of (int): Number of multiple of the hidden dimension. + ffn_dim_multiplier (Optional[float]): Dimension multiplier for the + feedforward layer. + norm_eps (float): Epsilon value for normalization. + qk_norm (bool): Whether to use normalization for queries and keys. + modulation (bool): Whether to use modulation for the attention + layer. """ super().__init__() self.dim = dim @@ -593,32 +623,30 @@ class JointTransformerBlock(GradientCheckpointMixin): self, x: torch.Tensor, x_mask: torch.Tensor, - freqs_cis: torch.Tensor, + pe: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, ): """ Perform a forward pass through the TransformerBlock. Args: - x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + x (Tensor): Input tensor. + pe (Tensor): Rope position embedding. Returns: - torch.Tensor: Output tensor after applying attention and + Tensor: Output tensor after applying attention and feedforward layers. """ if self.modulation: assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation( - adaln_input - ).chunk(4, dim=1) + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( self.attention( modulate(self.attention_norm1(x), scale_msa), x_mask, - freqs_cis, + pe, ) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( @@ -632,7 +660,7 @@ class JointTransformerBlock(GradientCheckpointMixin): self.attention( self.attention_norm1(x), x_mask, - freqs_cis, + pe, ) ) x = x + self.ffn_norm2( @@ -649,6 +677,14 @@ class FinalLayer(GradientCheckpointMixin): """ def __init__(self, hidden_size, patch_size, out_channels): + """ + Initialize the FinalLayer. + + Args: + hidden_size (int): Hidden size of the input features. + patch_size (int): Patch size of the input features. + out_channels (int): Number of output channels. + """ super().__init__() self.norm_final = nn.LayerNorm( hidden_size, @@ -682,39 +718,21 @@ class FinalLayer(GradientCheckpointMixin): class RopeEmbedder: - def __init__( - self, - theta: float = 10000.0, - axes_dims: List[int] = (16, 56, 56), - axes_lens: List[int] = (1, 512, 512), - ): + def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]): super().__init__() self.theta = theta self.axes_dims = axes_dims self.axes_lens = axes_lens - self.freqs_cis = NextDiT.precompute_freqs_cis( - self.axes_dims, self.axes_lens, theta=self.theta - ) + self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - def get_freqs_cis(self, ids: torch.Tensor): + def __call__(self, ids: torch.Tensor): + device = ids.device self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] result = [] for i in range(len(self.axes_dims)): - index = ( - ids[:, :, i : i + 1] - .repeat(1, 1, self.freqs_cis[i].shape[-1]) - .to(torch.int64) - ) - - axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1) - - result.append( - torch.gather( - axes, - dim=1, - index=index, - ) - ) + freqs = self.freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) return torch.cat(result, dim=-1) @@ -740,11 +758,63 @@ class NextDiT(nn.Module): axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512], ) -> None: + """ + Initialize the NextDiT model. + + Args: + patch_size (int): Patch size of the input features. + in_channels (int): Number of input channels. + dim (int): Hidden size of the input features. + n_layers (int): Number of Transformer layers. + n_refiner_layers (int): Number of refiner layers. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): Multiple of the hidden size. + ffn_dim_multiplier (Optional[float]): Dimension multiplier for the + feedforward layer. + norm_eps (float): Epsilon value for normalization. + qk_norm (bool): Whether to use query key normalization. + cap_feat_dim (int): Dimension of the caption features. + axes_dims (List[int]): List of dimensions for the axes. + axes_lens (List[int]): List of lengths for the axes. + + Returns: + None + """ super().__init__() self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.t_embedder = TimestepEmbedder(min(dim, 1024)) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + dim, + bias=True, + ), + ) + + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.x_embedder = nn.Linear( in_features=patch_size * patch_size * in_channels, out_features=dim, @@ -769,32 +839,7 @@ class NextDiT(nn.Module): for layer_id in range(n_refiner_layers) ] ) - self.context_refiner = nn.ModuleList( - [ - JointTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.t_embedder = TimestepEmbedder(min(dim, 1024)) - self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear( - cap_feat_dim, - dim, - bias=True, - ), - ) nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) # nn.init.zeros_(self.cap_embedder[1].weight) nn.init.zeros_(self.cap_embedder[1].bias) @@ -864,15 +909,26 @@ class NextDiT(nn.Module): def unpatchify( self, - x: torch.Tensor, + x: Tensor, width: int, height: int, encoder_seq_lengths: List[int], seq_lengths: List[int], - ) -> torch.Tensor: + ) -> Tensor: """ + Unpatchify the input tensor and embed the caption features. x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) + + Args: + x (Tensor): Input tensor. + width (int): Width of the input tensor. + height (int): Height of the input tensor. + encoder_seq_lengths (List[int]): List of encoder sequence lengths. + seq_lengths (List[int]): List of sequence lengths + + Returns: + output: (N, C, H, W) """ pH = pW = self.patch_size @@ -891,13 +947,27 @@ class NextDiT(nn.Module): def patchify_and_embed( self, - x: torch.Tensor, - cap_feats: torch.Tensor, - cap_mask: torch.Tensor, - t: torch.Tensor, - ) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int] - ]: + x: Tensor, + cap_feats: Tensor, + cap_mask: Tensor, + t: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: + """ + Patchify and embed the input image and caption features. + + Args: + x: (N, C, H, W) image latents + cap_feats: (N, C, D) caption features + cap_mask: (N, C, D) caption attention mask + t: (N), T timesteps + + Returns: + Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: + + return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths + + + """ bsz, channels, height, width = x.shape pH = pW = self.patch_size device = x.device @@ -915,40 +985,35 @@ class NextDiT(nn.Module): H_tokens, W_tokens = height // pH, width // pW position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len - row_ids = ( - torch.arange(H_tokens, dtype=torch.int32, device=device) - .view(-1, 1) - .repeat(1, W_tokens) - .flatten() - ) - col_ids = ( - torch.arange(W_tokens, dtype=torch.int32, device=device) - .view(1, -1) - .repeat(H_tokens, 1) - .flatten() - ) - position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids - position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids + position_ids[i, cap_len:seq_len, 0] = cap_len - freqs_cis = self.rope_embedder.get_freqs_cis(position_ids) + row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + position_ids[i, cap_len:seq_len, 1] = row_ids + position_ids[i, cap_len:seq_len, 2] = col_ids + + # Get combinded rotary embeddings + freqs_cis = self.rope_embedder(position_ids) + + # Create separate rotary embeddings for captions and images cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len] + img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len] - x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) - x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - - # refine context + # Refine caption context for layer in self.context_refiner: cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) + x = self.x_embedder(x) + # Refine image context for layer in self.noise_refiner: x = layer(x, x_mask, img_freqs_cis, t) @@ -963,19 +1028,23 @@ class NextDiT(nn.Module): return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths - def forward(self, x, t, cap_feats, cap_mask): + def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor: """ Forward pass of NextDiT. - t: (N,) tensor of diffusion timesteps - y: (N,) tensor of text tokens/features + Args: + x: (N, C, H, W) image latents + t: (N,) tensor of diffusion timesteps + cap_feats: (N, L, D) caption features + cap_mask: (N, L) caption attention mask + + Returns: + x: (N, C, H, W) denoised latents """ - _, _, height, width = x.shape # B, C, H, W + _, _, height, width = x.shape # B, C, H, W t = self.t_embedder(t) # (N, D) cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed( - x, cap_feats, cap_mask, t - ) + x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t) for layer in self.layers: x = layer(x, mask, freqs_cis, t) @@ -986,7 +1055,14 @@ class NextDiT(nn.Module): return x def forward_with_cfg( - self, x, t, cap_feats, cap_mask, cfg_scale, cfg_trunc=100, renorm_cfg=1 + self, + x: Tensor, + t: Tensor, + cap_feats: Tensor, + cap_mask: Tensor, + cfg_scale: float, + cfg_trunc: int = 100, + renorm_cfg: float = 1.0, ): """ Forward pass of NextDiT, but also batches the unconditional forward pass @@ -996,9 +1072,10 @@ class NextDiT(nn.Module): half = x[: len(x) // 2] if t[0] < cfg_trunc: combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128] - model_out = self.forward( - combined, t, cap_feats, cap_mask - ) # [2, 16, 128, 128] + assert ( + cap_mask.shape[0] == combined.shape[0] + ), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}" + model_out = self.forward(x, t, cap_feats, cap_mask) # [2, 16, 128, 128] # For exact reproducibility reasons, we apply classifier-free guidance on only # three channels by default. The standard approach to cfg applies it to all channels. # This can be done by uncommenting the following line and commenting-out the line following that. @@ -1009,13 +1086,9 @@ class NextDiT(nn.Module): cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) if float(renorm_cfg) > 0.0: - ori_pos_norm = torch.linalg.vector_norm( - cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True - ) + ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True) max_new_norm = ori_pos_norm * float(renorm_cfg) - new_pos_norm = torch.linalg.vector_norm( - half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True - ) + new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True) if new_pos_norm >= max_new_norm: half_eps = half_eps * (max_new_norm / new_pos_norm) else: @@ -1040,7 +1113,7 @@ class NextDiT(nn.Module): dim: List[int], end: List[int], theta: float = 10000.0, - ): + ) -> List[Tensor]: """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -1057,19 +1130,17 @@ class NextDiT(nn.Module): Defaults to 10000.0. Returns: - torch.Tensor: Precomputed frequency tensor with complex + List[torch.Tensor]: Precomputed frequency tensor with complex exponentials. """ freqs_cis = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / ( - theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d) - ) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to( - torch.complex64 - ) # complex64 + pos = torch.arange(e, dtype=freqs_dtype, device="cpu") + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d)) + freqs = torch.outer(pos, freqs) + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2] freqs_cis.append(freqs_cis_i) return freqs_cis @@ -1102,7 +1173,7 @@ class NextDiT(nn.Module): def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs): if params is None: params = LuminaParams.get_2b_config() - + return NextDiT( patch_size=params.patch_size, in_channels=params.in_channels, diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 7ade6c1b..9dac9c9f 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -2,20 +2,20 @@ import argparse import math import os import numpy as np -import toml -import json import time -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Any import torch +from torch import Tensor from accelerate import Accelerator, PartialState -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import Gemma2Model from tqdm import tqdm from PIL import Image from safetensors.torch import save_file -from library import lumina_models, lumina_util, strategy_base, train_util +from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util from library.device_utils import init_ipex, clean_memory_on_device +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler init_ipex() @@ -30,19 +30,38 @@ logger = logging.getLogger(__name__) # region sample images +@torch.no_grad() def sample_images( accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, - nextdit, - ae, - gemma2_model, - sample_prompts_gemma2_outputs, - prompt_replacement=None, - controlnet=None + epoch: int, + global_step: int, + nextdit: lumina_models.NextDiT, + vae: torch.nn.Module, + gemma2_model: Gemma2Model, + sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]], + prompt_replacement: Optional[Tuple[str, str]] = None, + controlnet=None, ): - if steps == 0: + """ + Generate sample images using the NextDiT model. + + Args: + accelerator (Accelerator): Accelerator instance. + args (argparse.Namespace): Command-line arguments. + epoch (int): Current epoch number. + global_step (int): Current global step number. + nextdit (lumina_models.NextDiT): The NextDiT model instance. + vae (torch.nn.Module): The VAE module. + gemma2_model (Gemma2Model): The Gemma2 model instance. + sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample. + prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None. + controlnet:: ControlNet model + + Returns: + None + """ + if global_step == 0: if not args.sample_at_first: return else: @@ -53,11 +72,15 @@ def sample_images( if epoch is None or epoch % args.sample_every_n_epochs != 0: return else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return + assert ( + args.sample_prompts is not None + ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください" + logger.info("") - logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}") if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -87,22 +110,21 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(), accelerator.autocast(): - for prompt_dict in prompts: - sample_image_inference( - accelerator, - args, - nextdit, - gemma2_model, - ae, - save_dir, - prompt_dict, - epoch, - steps, - sample_prompts_gemma2_outputs, - prompt_replacement, - controlnet - ) + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + vae, + save_dir, + prompt_dict, + epoch, + global_step, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet, + ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. @@ -110,23 +132,22 @@ def sample_images( for i in range(distributed_state.num_processes): per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - with torch.no_grad(): - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: - sample_image_inference( - accelerator, - args, - nextdit, - gemma2_model, - ae, - save_dir, - prompt_dict, - epoch, - steps, - sample_prompts_gemma2_outputs, - prompt_replacement, - controlnet - ) + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + vae, + save_dir, + prompt_dict, + epoch, + global_step, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet, + ) torch.set_rng_state(rng_state) if cuda_rng_state is not None: @@ -135,43 +156,60 @@ def sample_images( clean_memory_on_device(accelerator.device) +@torch.no_grad() def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - nextdit, - gemma2_model, - ae, - save_dir, - prompt_dict, - epoch, - steps, - sample_prompts_gemma2_outputs, - prompt_replacement, - # controlnet + nextdit: lumina_models.NextDiT, + gemma2_model: Gemma2Model, + vae: torch.nn.Module, + save_dir: str, + prompt_dict: Dict[str, str], + epoch: int, + global_step: int, + sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]], + prompt_replacement: Optional[Tuple[str, str]] = None, + controlnet=None, ): + """ + Generates sample images + + Args: + accelerator (Accelerator): Accelerator object + args (argparse.Namespace): Arguments object + nextdit (lumina_models.NextDiT): NextDiT model + gemma2_model (Gemma2Model): Gemma2 model + vae (torch.nn.Module): VAE model + save_dir (str): Directory to save images + prompt_dict (Dict[str, str]): Prompt dictionary + epoch (int): Epoch number + steps (int): Number of steps to run + sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs + prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None. + + Returns: + None + """ assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 20) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 3.5) - seed = prompt_dict.get("seed") + sample_steps = prompt_dict.get("sample_steps", 38) + width = prompt_dict.get("width", 1024) + height = prompt_dict.get("height", 1024) + guidance_scale: int = prompt_dict.get("scale", 3.5) + seed: int = prompt_dict.get("seed", None) controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") + negative_prompt: str = prompt_dict.get("negative_prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - # if negative_prompt is not None: - # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + generator = torch.Generator(device=accelerator.device) if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - else: - # True random sample image generation - torch.seed() - torch.cuda.seed() + generator.manual_seed(seed) # if negative_prompt is None: # negative_prompt = "" @@ -182,7 +220,7 @@ def sample_image_inference( logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {scale}") + logger.info(f"scale: {guidance_scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -191,14 +229,16 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) + assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) + gemma2_conds = [] if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: gemma2_conds = sample_prompts_gemma2_outputs[prompt] - print(f"Using cached Gemma2 outputs for prompt: {prompt}") + logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") if gemma2_model is not None: - print(f"Encoding prompt with Gemma2: {prompt}") + logger.info(f"Encoding prompt with Gemma2: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) - # strategy has apply_gemma2_attn_mask option encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) # if gemma2_conds is not cached, use encoded_gemma2_conds @@ -211,22 +251,26 @@ def sample_image_inference( gemma2_conds[i] = encoded_gemma2_conds[i] # Unpack Gemma2 outputs - gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds + gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds # sample image - weight_dtype = ae.dtype # TOFO give dtype as argument - packed_latent_height = height // 16 - packed_latent_width = width // 16 + weight_dtype = vae.dtype # TOFO give dtype as argument + latent_height = height // 8 + latent_width = width // 8 noise = torch.randn( 1, - packed_latent_height * packed_latent_width, - 16 * 2 * 2, + 16, + latent_height, + latent_width, device=accelerator.device, dtype=weight_dtype, - generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + generator=generator, ) + # Prompts are paired positive/negative + noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) - img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + # img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) # if controlnet_image is not None: @@ -235,18 +279,18 @@ def sample_image_inference( # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) - with accelerator.autocast(), torch.no_grad(): - x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + with accelerator.autocast(): + x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale) - x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) + # x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image clean_memory_on_device(accelerator.device) - org_vae_device = ae.device # will be on cpu - ae.to(accelerator.device) # distributed_state.device is same as accelerator.device - with accelerator.autocast(), torch.no_grad(): - x = ae.decode(x) - ae.to(org_vae_device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(): + x = vae.decode(x) + vae.to(org_vae_device) clean_memory_on_device(accelerator.device) x = x.clamp(-1, 1) @@ -257,9 +301,9 @@ def sample_image_inference( # but adding 'enum' to the filename should be enough ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}" seed_suffix = "" if seed is None else f"_{seed}" - i: int = prompt_dict["enum"] + i: int = int(prompt_dict.get("enum", 0)) img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) @@ -273,11 +317,34 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def time_shift(mu: float, sigma: float, t: torch.Tensor): +def time_shift(mu: float, sigma: float, t: Tensor): + """ + Get time shift + + Args: + mu (float): mu value. + sigma (float): sigma value. + t (Tensor): timestep. + + Return: + float: time shift + """ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + """ + Get linear function + + Args: + x1 (float, optional): x1 value. Defaults to 256. + y1 (float, optional): y1 value. Defaults to 0.5. + x2 (float, optional): x2 value. Defaults to 4096. + y2 (float, optional): y2 value. Defaults to 1.15. + + Return: + Callable[[float], float]: linear function + """ m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b @@ -290,6 +357,19 @@ def get_schedule( max_shift: float = 1.15, shift: bool = True, ) -> list[float]: + """ + Get timesteps schedule + + Args: + num_steps (int): Number of steps in the schedule. + image_seq_len (int): Sequence length of the image. + base_shift (float, optional): Base shift value. Defaults to 0.5. + max_shift (float, optional): Maximum shift value. Defaults to 1.15. + shift (bool, optional): Whether to shift the schedule. Defaults to True. + + Return: + List[float]: timesteps schedule + """ # extra step for zero timesteps = torch.linspace(1, 0, num_steps + 1) @@ -301,11 +381,63 @@ def get_schedule( return timesteps.tolist() + +def denoise( + model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0 +): + """ + Denoise an image using the NextDiT model. + + Args: + model (lumina_models.NextDiT): The NextDiT model instance. + img (Tensor): The input image tensor. + txt (Tensor): The input text tensor. + txt_mask (Tensor): The input text mask tensor. + timesteps (List[float]): A list of timesteps for the denoising process. + guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0. + + Returns: + img (Tensor): Denoised tensor + """ + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + # model.prepare_block_swap_before_forward() + # block_samples = None + # block_single_samples = None + pred = model.forward_with_cfg( + x=img, # image latents (B, C, H, W) + t=t_vec / 1000, # timesteps需要除以1000来匹配模型预期 + cap_feats=txt, # Gemma2的hidden states作为caption features + cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + cfg_scale=guidance, + ) + + img = img + (t_prev - t_curr) * pred + + # model.prepare_block_swap_before_forward() + return img + + # endregion # region train -def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): +def get_sigmas( + noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32 +) -> Tensor: + """ + Get sigmas for timesteps + + Args: + noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance. + timesteps (Tensor): A tensor of timesteps for the denoising process. + device (torch.device): The device on which the tensors are stored. + n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4. + dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32. + + Returns: + sigmas (Tensor): The sigmas tensor. + """ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(device) timesteps = timesteps.to(device) @@ -320,11 +452,22 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None ): - """Compute the density for sampling the timesteps when doing SD3 training. + """ + Compute the density for sampling the timesteps when doing SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme (str): The weighting scheme to use. + batch_size (int): The batch size for the sampling process. + logit_mean (float, optional): The mean of the logit distribution. Defaults to None. + logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None. + mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None. + + Returns: + u (Tensor): The sampled timesteps. """ if weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). @@ -338,12 +481,19 @@ def compute_density_for_timestep_sampling( return u -def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor: """Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme (str): The weighting scheme to use. + sigmas (Tensor, optional): The sigmas tensor. Defaults to None. + + Returns: + u (Tensor): The sampled timesteps. """ if weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() @@ -355,9 +505,24 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]: + """ + Get noisy model input and timesteps. + + Args: + args (argparse.Namespace): Arguments. + noise_scheduler (noise_scheduler): Noise scheduler. + latents (Tensor): Latents. + noise (Tensor): Latent noise. + device (torch.device): Device. + dtype (torch.dtype): Data type + + Return: + Tuple[Tensor, Tensor, Tensor]: + noisy model input + timesteps + sigmas + """ bsz, _, h, w = latents.shape sigmas = None @@ -412,7 +577,21 @@ def get_noisy_model_input_and_timesteps( return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas -def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): +def apply_model_prediction_type( + args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor +) -> Tuple[Tensor, Optional[Tensor]]: + """ + Apply model prediction type to the model prediction and the sigmas. + + Args: + args (argparse.Namespace): Arguments. + model_pred (Tensor): Model prediction. + noisy_model_input (Tensor): Noisy model input. + sigmas (Tensor): Sigmas. + + Return: + Tuple[Tensor, Optional[Tensor]]: + """ weighting = None if args.model_prediction_type == "raw": pass @@ -433,10 +612,22 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): def save_models( ckpt_path: str, lumina: lumina_models.NextDiT, - sai_metadata: Optional[dict], + sai_metadata: Dict[str, Any], save_dtype: Optional[torch.dtype] = None, use_mem_eff_save: bool = False, ): + """ + Save the model to the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + lumina (lumina_models.NextDiT): NextDIT model. + sai_metadata (Optional[dict]): Metadata for the SAI model. + save_dtype (Optional[torch.dtype]): Data + + Return: + None + """ state_dict = {} def update_sd(prefix, sd): @@ -458,7 +649,9 @@ def save_lumina_model_on_train_end( args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT ): def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2" + ) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) @@ -469,15 +662,29 @@ def save_lumina_model_on_train_end( def save_lumina_model_on_epoch_end_or_stepwise( args: argparse.Namespace, on_epoch_end: bool, - accelerator, + accelerator: Accelerator, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, lumina: lumina_models.NextDiT, ): - def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + """ + Save the model to the checkpoint path. + + Args: + args (argparse.Namespace): Arguments. + save_dtype (torch.dtype): Data type. + epoch (int): Epoch. + global_step (int): Global step. + lumina (lumina_models.NextDiT): NextDIT model. + + Return: + None + """ + + def sd_saver(ckpt_file: str, epoch_no: int, global_step: int): + sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_epoch_end_or_stepwise_common( diff --git a/library/lumina_util.py b/library/lumina_util.py index f8e3f7db..f404e775 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -11,23 +11,33 @@ from safetensors.torch import load_file from transformers import Gemma2Config, Gemma2Model from library.utils import setup_logging - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - from library import lumina_models, flux_models from library.utils import load_safetensors +import logging + +setup_logging() +logger = logging.getLogger(__name__) MODEL_VERSION_LUMINA_V2 = "lumina2" def load_lumina_model( ckpt_path: str, dtype: torch.dtype, - device: Union[str, torch.device], + device: torch.device, disable_mmap: bool = False, ): + """ + Load the Lumina model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (torch.device): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + + Returns: + model (lumina_models.NextDiT): The loaded model. + """ logger.info("Building Lumina") with torch.device("meta"): model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) @@ -46,6 +56,18 @@ def load_ae( device: Union[str, torch.device], disable_mmap: bool = False, ) -> flux_models.AutoEncoder: + """ + Load the AutoEncoder model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (Union[str, torch.device]): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + + Returns: + ae (flux_models.AutoEncoder): The loaded model. + """ logger.info("Building AutoEncoder") with torch.device("meta"): # dev and schnell have the same AE params @@ -67,6 +89,19 @@ def load_gemma2( disable_mmap: bool = False, state_dict: Optional[dict] = None, ) -> Gemma2Model: + """ + Load the Gemma2 model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (Union[str, torch.device]): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + state_dict (Optional[dict], optional): The state dict to load. Defaults to None. + + Returns: + gemma2 (Gemma2Model): The loaded model + """ logger.info("Building Gemma2") GEMMA2_CONFIG = { "_name_or_path": "google/gemma-2-2b", diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 209f62a0..0a6a7f29 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -130,11 +130,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) return False if "input_ids" not in npz: return False - if "apply_gemma2_attn_mask" not in npz: - return False - npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"] - if not npz_apply_gemma2_attn_mask: - return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -142,11 +137,17 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) return True def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + """ + Load outputs from a npz file + + Returns: + List[np.ndarray]: hidden_state, input_ids, attention_mask + """ data = np.load(npz_path) hidden_state = data["hidden_state"] attention_mask = data["attention_mask"] input_ids = data["input_ids"] - return [hidden_state, attention_mask, input_ids] + return [hidden_state, input_ids, attention_mask] def cache_batch_outputs( self, @@ -193,8 +194,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) info.text_encoder_outputs_npz, hidden_state=hidden_state_i, attention_mask=attention_mask_i, - input_ids=input_ids_i, - apply_gemma2_attn_mask=True + input_ids=input_ids_i ) else: info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] diff --git a/lumina_train_network.py b/lumina_train_network.py index 00c81bce..81acfb51 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -2,9 +2,10 @@ import argparse import copy import math import random -from typing import Any, Optional, Union +from typing import Any, Optional, Union, Tuple import torch +from torch import Tensor from accelerate import Accelerator from library.device_utils import clean_memory_on_device, init_ipex @@ -165,36 +166,31 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" ) - tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = ( - strategy_base.TokenizeStrategy.get_strategy() - ) - text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( - strategy_base.TextEncodingStrategy.get_strategy() - ) + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) + assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = ( {} ) # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [ - prompt_dict.get("prompt", ""), - prompt_dict.get("negative_prompt", ""), - ]: - if p not in sample_prompts_te_outputs: - logger.info( - f"cache Text Encoder outputs for prompt: {p}" - ) - tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = ( - text_encoding_strategy.encode_tokens( - tokenize_strategy, - text_encoders, - tokens_and_masks, - args.apply_t5_attn_mask, - ) - ) + for prompt_dict in sample_prompts: + prompts = [prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", "")] + logger.info( + f"cache Text Encoder outputs for prompt: {prompts[0]}" + ) + tokens_and_masks = tokenize_strategy.tokenize(prompts) + sample_prompts_te_outputs[prompts[0]] = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + ) + ) self.sample_prompts_te_outputs = sample_prompts_te_outputs accelerator.wait_for_everyone() @@ -220,7 +216,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): epoch, global_step, device, - ae, + vae, tokenizer, text_encoder, lumina, @@ -231,7 +227,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): epoch, global_step, lumina, - ae, + vae, self.get_models_for_text_encoding(args, accelerator, text_encoder), self.sample_prompts_te_outputs, ) @@ -258,12 +254,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def get_noise_pred_and_target( self, args, - accelerator, + accelerator: Accelerator, noise_scheduler, latents, batch, - text_encoder_conds, - unet: lumina_models.NextDiT, + text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks) + dit: lumina_models.NextDiT, network, weight_dtype, train_unet, @@ -296,7 +292,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) - model_pred = unet( + model_pred = dit( x=img, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features @@ -341,7 +337,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): network.set_multiplier(0.0) with torch.no_grad(): model_pred_prior = call_dit( - img=packed_noisy_model_input[diff_output_pr_indices], + img=noisy_model_input[diff_output_pr_indices], gemma2_hidden_states=gemma2_hidden_states[ diff_output_pr_indices ], @@ -350,9 +346,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) network.set_multiplier(1.0) - model_pred_prior = lumina_util.unpack_latents( - model_pred_prior, packed_latent_height, packed_latent_width - ) + # model_pred_prior = lumina_util.unpack_latents( + # model_pred_prior, packed_latent_height, packed_latent_width + # ) model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( args, model_pred_prior, @@ -404,7 +400,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return super().prepare_unet_with_accelerator(args, accelerator, unet) # if we doesn't swap blocks, we can move the model to device - nextdit: lumina_models.Nextdit = unet + nextdit = unet + assert isinstance(nextdit, lumina_models.NextDiT) nextdit = accelerator.prepare( nextdit, device_placement=[not self.is_swapping_blocks] ) From bd16bd13ae97a02ffee34346d254384bc40c7b30 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Feb 2025 01:21:18 -0500 Subject: [PATCH 14/73] Remove unused attention, fix typo --- library/lumina_models.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index e82f3b2c..36c3b979 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -467,13 +467,6 @@ class JointAttention(nn.Module): return self.out(output) -def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: - x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - x = rearrange(x, "B H L D -> B L (H D)") - - return x - - def apply_rope( x_in: torch.Tensor, freqs_cis: torch.Tensor, @@ -965,8 +958,6 @@ class NextDiT(nn.Module): Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths - - """ bsz, channels, height, width = x.shape pH = pW = self.patch_size @@ -993,7 +984,7 @@ class NextDiT(nn.Module): position_ids[i, cap_len:seq_len, 1] = row_ids position_ids[i, cap_len:seq_len, 2] = col_ids - # Get combinded rotary embeddings + # Get combined rotary embeddings freqs_cis = self.rope_embedder(position_ids) # Create separate rotary embeddings for captions and images From 025cca699ba0ee05b91d37e5b7779ec28d076620 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 01:29:18 -0500 Subject: [PATCH 15/73] Fix samples, LoRA training. Add system prompt, use_flash_attn --- library/config_util.py | 6 + library/lumina_models.py | 200 +++++++++++++----------- library/lumina_train_util.py | 289 ++++++++++++++++++++++++++--------- library/lumina_util.py | 86 +++++------ library/sd3_train_utils.py | 259 ++++++++++++++++++++++++++----- library/strategy_base.py | 84 ++++++++-- library/strategy_lumina.py | 153 +++++++++++++++---- library/train_util.py | 21 ++- lumina_train_network.py | 171 +++++++++------------ train_network.py | 5 +- 10 files changed, 888 insertions(+), 386 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a2e07dc6..ca14dfb1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,6 +75,7 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 + system_prompt: Optional[str] = None @dataclass @@ -106,6 +107,7 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 + system_prompt: Optional[str] = None @dataclass @@ -196,6 +198,7 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, + "system_prompt": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -241,6 +244,7 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "system_prompt": str, } # options handled by argparse but not handled by user config @@ -526,6 +530,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} + system_prompt: {dataset.system_prompt} """) if dataset.enable_bucket: @@ -559,6 +564,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu token_warmup_step: {subset.token_warmup_step}, alpha_mask: {subset.alpha_mask} custom_attributes: {subset.custom_attributes} + system_prompt: {subset.system_prompt} """), " ") if is_dreambooth: diff --git a/library/lumina_models.py b/library/lumina_models.py index 36c3b979..f819b68f 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -14,14 +14,19 @@ from typing import List, Optional, Tuple from dataclasses import dataclass from einops import rearrange -from flash_attn import flash_attn_varlen_func -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F +try: + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + pass + try: from apex.normalization import FusedRMSNorm as RMSNorm except: @@ -75,7 +80,15 @@ class LuminaParams: @classmethod def get_7b_config(cls) -> "LuminaParams": """Returns the configuration for the 7B parameter model""" - return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512]) + return cls( + patch_size=2, + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[64, 64, 64], + axes_lens=[300, 512, 512], + ) class GradientCheckpointMixin(nn.Module): @@ -248,6 +261,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + use_flash_attn=False, ): """ Initialize the Attention module. @@ -286,7 +300,7 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - self.flash_attn = False + self.use_flash_attn = use_flash_attn # self.attention_processor = xformers.ops.memory_efficient_attention self.attention_processor = F.scaled_dot_product_attention @@ -294,35 +308,63 @@ class JointAttention(nn.Module): def set_attention_processor(self, attention_processor): self.attention_processor = attention_processor - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: + def forward( + self, + x: Tensor, + x_mask: Tensor, + freqs_cis: Tensor, + ) -> Tensor: """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. + x: + x_mask: + freqs_cis: """ - with torch.autocast("cuda", enabled=False): - x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x * freqs_cis).flatten(3) - return x_out.type_as(x_in) + bsz, seqlen, _ = x.shape + dtype = x.dtype + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = apply_rope(xq, freqs_cis=freqs_cis) + xk = apply_rope(xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + softmax_scale = math.sqrt(1 / self.head_dim) + + if self.use_flash_attn: + output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) + else: + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = ( + self.attention_processor( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), + scale=softmax_scale, + ) + .permute(0, 2, 1, 3) + .to(dtype) + ) + + output = output.flatten(-2) + return self.out(output) # copied from huggingface modeling_llama.py def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): @@ -377,46 +419,17 @@ class JointAttention(nn.Module): (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) - def forward( + def flash_attn( self, - x: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, x_mask: Tensor, - freqs_cis: Tensor, + softmax_scale, ) -> Tensor: - """ + bsz, seqlen, _, _ = q.shape - Args: - x: - x_mask: - freqs_cis: - - Returns: - - """ - bsz, seqlen, _ = x.shape - dtype = x.dtype - - xq, xk, xv = torch.split( - self.qkv(x), - [ - self.n_local_heads * self.head_dim, - self.n_local_kv_heads * self.head_dim, - self.n_local_kv_heads * self.head_dim, - ], - dim=-1, - ) - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xq = self.q_norm(xq) - xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) - xq, xk = xq.to(dtype), xk.to(dtype) - - softmax_scale = math.sqrt(1 / self.head_dim) - - if self.flash_attn: + try: # begin var_len flash attn ( query_states, @@ -425,7 +438,7 @@ class JointAttention(nn.Module): indices_q, cu_seq_lens, max_seq_lens, - ) = self._upad_input(xq, xk, xv, x_mask, seqlen) + ) = self._upad_input(q, k, v, x_mask, seqlen) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -445,27 +458,12 @@ class JointAttention(nn.Module): output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) # end var_len_flash_attn - else: - n_rep = self.n_local_heads // self.n_local_kv_heads - if n_rep >= 1: - xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - - output = ( - self.attention_processor( - xq.permute(0, 2, 1, 3), - xk.permute(0, 2, 1, 3), - xv.permute(0, 2, 1, 3), - attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), - scale=softmax_scale, - ) - .permute(0, 2, 1, 3) - .to(dtype) + return output + except NameError as e: + raise RuntimeError( + f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}" ) - output = output.flatten(-2) - return self.out(output) - def apply_rope( x_in: torch.Tensor, @@ -563,6 +561,7 @@ class JointTransformerBlock(GradientCheckpointMixin): norm_eps: float, qk_norm: bool, modulation=True, + use_flash_attn=False, ) -> None: """ Initialize a TransformerBlock. @@ -585,7 +584,7 @@ class JointTransformerBlock(GradientCheckpointMixin): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn) self.feed_forward = FeedForward( dim=dim, hidden_dim=4 * dim, @@ -711,7 +710,12 @@ class FinalLayer(GradientCheckpointMixin): class RopeEmbedder: - def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]): + def __init__( + self, + theta: float = 10000.0, + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], + ): super().__init__() self.theta = theta self.axes_dims = axes_dims @@ -750,6 +754,7 @@ class NextDiT(nn.Module): cap_feat_dim: int = 5120, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512], + use_flash_attn=False, ) -> None: """ Initialize the NextDiT model. @@ -803,6 +808,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=False, + use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -828,6 +834,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -848,6 +855,7 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + use_flash_attn=use_flash_attn, ) for layer_id in range(n_layers) ] @@ -988,8 +996,20 @@ class NextDiT(nn.Module): freqs_cis = self.rope_embedder(position_ids) # Create separate rotary embeddings for captions and images - cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) - img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) + cap_freqs_cis = torch.zeros( + bsz, + encoder_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + img_freqs_cis = torch.zeros( + bsz, + image_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 9dac9c9f..414b2849 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1,21 +1,28 @@ +import inspect +import enum import argparse import math import os import numpy as np import time -from typing import Callable, Dict, List, Optional, Tuple, Any +from typing import Callable, Dict, List, Optional, Tuple, Any, Union import torch from torch import Tensor +from torchdiffeq import odeint from accelerate import Accelerator, PartialState from transformers import Gemma2Model from tqdm import tqdm from PIL import Image from safetensors.torch import save_file +from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util +from library.flux_models import AutoEncoder from library.device_utils import init_ipex, clean_memory_on_device from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.lumina_dpm_solver import NoiseScheduleFlow, DPM_Solver +import library.lumina_path as path init_ipex() @@ -162,12 +169,12 @@ def sample_image_inference( args: argparse.Namespace, nextdit: lumina_models.NextDiT, gemma2_model: Gemma2Model, - vae: torch.nn.Module, + vae: AutoEncoder, save_dir: str, prompt_dict: Dict[str, str], epoch: int, global_step: int, - sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]], + sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]], prompt_replacement: Optional[Tuple[str, str]] = None, controlnet=None, ): @@ -179,12 +186,12 @@ def sample_image_inference( args (argparse.Namespace): Arguments object nextdit (lumina_models.NextDiT): NextDiT model gemma2_model (Gemma2Model): Gemma2 model - vae (torch.nn.Module): VAE model + vae (AutoEncoder): VAE model save_dir (str): Directory to save images prompt_dict (Dict[str, str]): Prompt dictionary epoch (int): Epoch number steps (int): Number of steps to run - sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs + sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing Gemma 2 outputs prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None. Returns: @@ -192,16 +199,19 @@ def sample_image_inference( """ assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 38) - width = prompt_dict.get("width", 1024) - height = prompt_dict.get("height", 1024) - guidance_scale: int = prompt_dict.get("scale", 3.5) - seed: int = prompt_dict.get("seed", None) + sample_steps = int(prompt_dict.get("sample_steps", 38)) + width = int(prompt_dict.get("width", 1024)) + height = int(prompt_dict.get("height", 1024)) + guidance_scale = float(prompt_dict.get("scale", 3.5)) + seed = prompt_dict.get("seed", None) controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") negative_prompt: str = prompt_dict.get("negative_prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + seed = int(seed) if seed is not None else None + assert seed is None or seed > 0, f"Invalid seed {seed}" + if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: @@ -213,10 +223,10 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - height = max(64, height - height % 16) # round to divisible by 16 - width = max(64, width - width % 16) # round to divisible by 16 + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 logger.info(f"prompt: {prompt}") - # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"negative_prompt: {negative_prompt}") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") @@ -232,46 +242,51 @@ def sample_image_inference( assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - gemma2_conds = [] + system_prompt = args.system_prompt or "" + + # Apply system prompt to prompts + prompt = system_prompt + prompt + negative_prompt = system_prompt + negative_prompt + + # Get sample prompts from cache if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: gemma2_conds = sample_prompts_gemma2_outputs[prompt] logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") + + if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: + neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] + logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + + # Load sample prompts from Gemma 2 if gemma2_model is not None: logger.info(f"Encoding prompt with Gemma2: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) - encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) - # if gemma2_conds is not cached, use encoded_gemma2_conds - if len(gemma2_conds) == 0: - gemma2_conds = encoded_gemma2_conds - else: - # if encoded_gemma2_conds is not None, update cached gemma2_conds - for i in range(len(encoded_gemma2_conds)): - if encoded_gemma2_conds[i] is not None: - gemma2_conds[i] = encoded_gemma2_conds[i] + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) # Unpack Gemma2 outputs gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds + neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds # sample image weight_dtype = vae.dtype # TOFO give dtype as argument latent_height = height // 8 latent_width = width // 8 + latent_channels = 16 noise = torch.randn( 1, - 16, + latent_channels, latent_height, latent_width, device=accelerator.device, dtype=weight_dtype, generator=generator, ) - # Prompts are paired positive/negative - noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1) - timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) - # img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) - gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) + scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0, use_karras_sigmas=True) + timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) # if controlnet_image is not None: # controlnet_image = Image.open(controlnet_image).convert("RGB") @@ -280,16 +295,25 @@ def sample_image_inference( # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) with accelerator.autocast(): - x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale) + x = denoise( + scheduler, + nextdit, + noise, + gemma2_hidden_states, + gemma2_attn_mask.to(accelerator.device), + neg_gemma2_hidden_states, + neg_gemma2_attn_mask.to(accelerator.device), + timesteps=timesteps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ) - # x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) - - # latent to image + # Latent to image clean_memory_on_device(accelerator.device) org_vae_device = vae.device # will be on cpu vae.to(accelerator.device) # distributed_state.device is same as accelerator.device with accelerator.autocast(): - x = vae.decode(x) + x = vae.decode((x / vae.scale_factor) + vae.shift_factor) vae.to(org_vae_device) clean_memory_on_device(accelerator.device) @@ -317,30 +341,25 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def time_shift(mu: float, sigma: float, t: Tensor): - """ - Get time shift - - Args: - mu (float): mu value. - sigma (float): sigma value. - t (Tensor): timestep. - - Return: - float: time shift - """ - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) +def time_shift(mu: float, sigma: float, t: torch.Tensor): + # the following implementation was original for t=0: clean / t=1: noise + # Since we adopt the reverse, the 1-t operations are needed + t = 1 - t + t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + t = 1 - t + return t -def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: +def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]: """ Get linear function Args: - x1 (float, optional): x1 value. Defaults to 256. - y1 (float, optional): y1 value. Defaults to 0.5. - x2 (float, optional): x2 value. Defaults to 4096. - y2 (float, optional): y2 value. Defaults to 1.15. + image_seq_len, + x1 base_seq_len: int = 256, + y2 max_seq_len: int = 4096, + y1 base_shift: float = 0.5, + y2 max_shift: float = 1.15, Return: Callable[[float], float]: linear function @@ -370,51 +389,164 @@ def get_schedule( Return: List[float]: timesteps schedule """ - # extra step for zero - timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = torch.linspace(1, 1 / num_steps, num_steps) # shifting the schedule to favor high timesteps for higher signal images if shift: # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +) -> Tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + def denoise( - model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0 + scheduler, + model: lumina_models.NextDiT, + img: Tensor, + txt: Tensor, + txt_mask: Tensor, + neg_txt: Tensor, + neg_txt_mask: Tensor, + timesteps: Union[List[float], torch.Tensor], + num_inference_steps: int = 38, + guidance_scale: float = 4.0, + cfg_trunc_ratio: float = 1.0, + cfg_normalization: bool = True, ): """ Denoise an image using the NextDiT model. Args: + scheduler (): + Noise scheduler model (lumina_models.NextDiT): The NextDiT model instance. - img (Tensor): The input image tensor. - txt (Tensor): The input text tensor. - txt_mask (Tensor): The input text mask tensor. - timesteps (List[float]): A list of timesteps for the denoising process. - guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0. + img (Tensor): + The input image latent tensor. + txt (Tensor): + The input text tensor. + txt_mask (Tensor): + The input text mask tensor. + neg_txt (Tensor): + The negative input txt tensor + neg_txt_mask (Tensor): + The negative input text mask tensor. + timesteps (List[Union[float, torch.FloatTensor]]): + A list of timesteps for the denoising process. + guidance_scale (float, optional): + The guidance scale for the denoising process. Defaults to 4.0. + cfg_trunc_ratio (float, optional): + The ratio of the timestep interval to apply normalization-based guidance scale. + cfg_normalization (bool, optional): + Whether to apply normalization-based guidance scale. Returns: - img (Tensor): Denoised tensor + img (Tensor): Denoised latent tensor """ - for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - # model.prepare_block_swap_before_forward() - # block_samples = None - # block_single_samples = None - pred = model.forward_with_cfg( - x=img, # image latents (B, C, H, W) - t=t_vec / 1000, # timesteps需要除以1000来匹配模型预期 + + for i, t in enumerate(tqdm(timesteps)): + # compute whether apply classifier-free truncation on this timestep + do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - t / scheduler.config.num_train_timesteps + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(img.shape[0]).to(model.device) + + noise_pred_cond = model( + img, + current_timestep, cap_feats=txt, # Gemma2的hidden states作为caption features cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask - cfg_scale=guidance, ) - img = img + (t_prev - t_curr) * pred + if not do_classifier_free_truncation: + noise_pred_uncond = model( + img, + current_timestep, + cap_feats=neg_txt, # Gemma2的hidden states作为caption features + cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + # apply normalization after classifier-free guidance + if cfg_normalization: + cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm) + else: + noise_pred = noise_pred_cond + + img_dtype = img.dtype + + if img.dtype != img_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + img = img.to(img_dtype) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = -noise_pred + img = scheduler.step(noise_pred, t, img, return_dict=False)[0] - # model.prepare_block_swap_before_forward() return img @@ -754,3 +886,14 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。", + ) + parser.add_argument( + "--system_prompt", + type=str, + default="You are an assistant designed to generate high-quality images based on user prompts. ", + help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。", + ) diff --git a/library/lumina_util.py b/library/lumina_util.py index f404e775..d9c89938 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -20,11 +20,13 @@ logger = logging.getLogger(__name__) MODEL_VERSION_LUMINA_V2 = "lumina2" + def load_lumina_model( ckpt_path: str, - dtype: torch.dtype, + dtype: Optional[torch.dtype], device: torch.device, disable_mmap: bool = False, + use_flash_attn: bool = False, ): """ Load the Lumina model from the checkpoint path. @@ -34,22 +36,22 @@ def load_lumina_model( dtype (torch.dtype): The data type for the model. device (torch.device): The device to load the model on. disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False. Returns: model (lumina_models.NextDiT): The loaded model. """ logger.info("Building Lumina") with torch.device("meta"): - model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - state_dict = load_safetensors( - ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype - ) + state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) info = model.load_state_dict(state_dict, strict=False, assign=True) logger.info(f"Loaded Lumina: {info}") return model + def load_ae( ckpt_path: str, dtype: torch.dtype, @@ -74,9 +76,7 @@ def load_ae( ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors( - ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype - ) + sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae @@ -104,37 +104,35 @@ def load_gemma2( """ logger.info("Building Gemma2") GEMMA2_CONFIG = { - "_name_or_path": "google/gemma-2-2b", - "architectures": [ - "Gemma2Model" - ], - "attention_bias": False, - "attention_dropout": 0.0, - "attn_logit_softcapping": 50.0, - "bos_token_id": 2, - "cache_implementation": "hybrid", - "eos_token_id": 1, - "final_logit_softcapping": 30.0, - "head_dim": 256, - "hidden_act": "gelu_pytorch_tanh", - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2304, - "initializer_range": 0.02, - "intermediate_size": 9216, - "max_position_embeddings": 8192, - "model_type": "gemma2", - "num_attention_heads": 8, - "num_hidden_layers": 26, - "num_key_value_heads": 4, - "pad_token_id": 0, - "query_pre_attn_scalar": 256, - "rms_norm_eps": 1e-06, - "rope_theta": 10000.0, - "sliding_window": 4096, - "torch_dtype": "float32", - "transformers_version": "4.44.2", - "use_cache": True, - "vocab_size": 256000 + "_name_or_path": "google/gemma-2-2b", + "architectures": ["Gemma2Model"], + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": 50.0, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": 1, + "final_logit_softcapping": 30.0, + "head_dim": 256, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2304, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 8192, + "model_type": "gemma2", + "num_attention_heads": 8, + "num_hidden_layers": 26, + "num_key_value_heads": 4, + "pad_token_id": 0, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "sliding_window": 4096, + "torch_dtype": "float32", + "transformers_version": "4.44.2", + "use_cache": True, + "vocab_size": 256000, } config = Gemma2Config(**GEMMA2_CONFIG) @@ -145,9 +143,7 @@ def load_gemma2( sd = state_dict else: logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors( - ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype - ) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) for key in list(sd.keys()): new_key = key.replace("model.", "") @@ -159,6 +155,7 @@ def load_gemma2( logger.info(f"Loaded Gemma2: {info}") return gemma2 + def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: """ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 @@ -174,6 +171,7 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) return x + DIFFUSERS_TO_ALPHA_VLLM_MAP = { # Embedding layers "cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"], @@ -224,9 +222,7 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict for block_idx in range(num_double_blocks): if str(block_idx) in key: converted = pattern.replace("()", str(block_idx)) - new_key = key.replace( - converted, replacement.replace("()", str(block_idx)) - ) + new_key = key.replace(converted, replacement.replace("()", str(block_idx))) break if new_key == key: diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c4079884..6a4b39b3 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -610,6 +610,21 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import BaseOutput +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + @dataclass class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ @@ -649,22 +664,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): self, num_train_timesteps: int = 1000, shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -690,6 +732,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -709,10 +754,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: A scaled input sample. """ - if self.step_index is None: - self._init_step_index(timestep) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) - sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -720,7 +786,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -730,18 +826,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps - timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps.to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas self._step_index = None self._begin_index = None @@ -807,7 +934,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" @@ -823,30 +954,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + prev_sample = sample + (sigma_next - sigma) * model_output - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) - - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - - # if self.config.prediction_type == "vector_field": - - denoised = sample - model_output * sigma - # 2. Convert to an ODE derivative - derivative = (sample - denoised) / sigma_hat - - dt = self.sigmas[self.step_index + 1] - sigma_hat - - prev_sample = sample + derivative * dt # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) @@ -858,6 +969,86 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def __len__(self): return self.config.num_train_timesteps diff --git a/library/strategy_base.py b/library/strategy_base.py index 358e42f1..fad79682 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -2,7 +2,7 @@ import os import re -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import torch @@ -430,9 +430,21 @@ class LatentsCachingStrategy: bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, - alpha_mask: bool, + apply_alpha_mask: bool, multi_resolution: bool = False, - ): + ) -> bool: + """ + Args: + latents_stride: stride of latents + bucket_reso: resolution of the bucket + npz_path: path to the npz file + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + multi_resolution: whether to use multi-resolution latents + + Returns: + bool + """ if not self.cache_to_disk: return False if not os.path.exists(npz_path): @@ -451,7 +463,7 @@ class LatentsCachingStrategy: return False if flip_aug and "latents_flipped" + key_reso_suffix not in npz: return False - if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: return False except Exception as e: logger.error(f"Error loading file: {npz_path}") @@ -462,22 +474,35 @@ class LatentsCachingStrategy: # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( self, - encode_by_vae, - vae_device, - vae_dtype, + encode_by_vae: Callable, + vae_device: torch.device, + vae_dtype: torch.dtype, image_infos: List, flip_aug: bool, - alpha_mask: bool, + apply_alpha_mask: bool, random_crop: bool, multi_resolution: bool = False, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + + Args: + encode_by_vae: function to encode images by VAE + vae_device: device to use for VAE + vae_dtype: dtype to use for VAE + image_infos: list of ImageInfo + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + random_crop: whether to random crop images + multi_resolution: whether to use multi-resolution latents + + Returns: + None """ from library import train_util # import here to avoid circular import img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop + image_infos, apply_alpha_mask, random_crop ) img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) @@ -519,12 +544,40 @@ class LatentsCachingStrategy: ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ for SD/SDXL + + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) def _default_load_latents_from_disk( self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + Args: + latents_stride (Optional[int]): Stride for latents. If None, load all latents. + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask + """ if latents_stride is None: key_reso_suffix = "" else: @@ -552,6 +605,19 @@ class LatentsCachingStrategy: alpha_mask=None, key_reso_suffix="", ): + """ + Args: + npz_path (str): Path to the npz file. + latents_tensor (torch.Tensor): Latent tensor + original_size (List[int]): Original size of the image + crop_ltrb (List[int]): Crop left top right bottom + flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor + alpha_mask (Optional[torch.Tensor]): Alpha mask + key_reso_suffix (str): Key resolution suffix + + Returns: + None + """ kwargs = {} if os.path.exists(npz_path): diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 0a6a7f29..5d6e100f 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -3,13 +3,13 @@ import os from typing import Any, List, Optional, Tuple, Union import torch -from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast +from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast from library import train_util from library.strategy_base import ( LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy, - TextEncoderOutputsCachingStrategy + TextEncoderOutputsCachingStrategy, ) import numpy as np from library.utils import setup_logging @@ -37,21 +37,38 @@ class LuminaTokenizeStrategy(TokenizeStrategy): else: self.max_length = max_length - def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]: + def tokenize( + self, text: Union[str, List[str]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + token input ids, attention_masks + """ text = [text] if isinstance(text, str) else text encodings = self.tokenizer( text, max_length=self.max_length, return_tensors="pt", - padding=True, + padding="max_length", pad_to_multiple_of=8, - truncation=True, ) - return [encodings.input_ids, encodings.attention_mask] + return (encodings.input_ids, encodings.attention_mask) def tokenize_with_weights( self, text: str | List[str] ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + token input ids, attention_masks, weights + """ # Gemma doesn't support weighted prompts, return uniform weights tokens, attention_masks = self.tokenize(text) weights = [torch.ones_like(t) for t in tokens] @@ -66,9 +83,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: List[torch.Tensor], + tokens: Tuple[torch.Tensor, torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ text_encoder = models[0] + assert isinstance(text_encoder, Gemma2Model) input_ids, attention_masks = tokens outputs = text_encoder( @@ -84,9 +112,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: List[torch.Tensor], - weights_list: List[torch.Tensor], + tokens: Tuple[torch.Tensor, torch.Tensor], + weights: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + weights_list (List[torch.Tensor]): Currently unused + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ # For simplicity, use uniform weighting return self.encode_tokens(tokenize_strategy, models, tokens) @@ -114,7 +153,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX ) - def is_disk_cached_outputs_expected(self, npz_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + """ + Args: + npz_path (str): Path to the npz file. + + Returns: + bool: True if the npz file is expected to be cached. + """ if not self.cache_to_disk: return False if not os.path.exists(npz_path): @@ -141,7 +187,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) Load outputs from a npz file Returns: - List[np.ndarray]: hidden_state, input_ids, attention_mask + List[np.ndarray]: hidden_state, input_ids, attention_mask """ data = np.load(npz_path) hidden_state = data["hidden_state"] @@ -151,53 +197,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) def cache_batch_outputs( self, - tokenize_strategy: LuminaTokenizeStrategy, + tokenize_strategy: TokenizeStrategy, models: List[Any], - text_encoding_strategy: LuminaTextEncodingStrategy, - infos: List, - ): - lumina_text_encoding_strategy: LuminaTextEncodingStrategy = ( - text_encoding_strategy - ) - captions = [info.caption for info in infos] + text_encoding_strategy: TextEncodingStrategy, + batch: List[train_util.ImageInfo], + ) -> None: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + text_encoding_strategy (LuminaTextEncodingStrategy): + infos (List): List of image_info + + Returns: + None + """ + assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) + assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) + + captions = [info.system_prompt or "" + info.caption for info in batch] if self.is_weighted: - tokens, weights_list = tokenize_strategy.tokenize_with_weights( - captions + tokens, attention_masks, weights_list = ( + tokenize_strategy.tokenize_with_weights(captions) ) with torch.no_grad(): - hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, models, tokens, weights_list + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + models, + (tokens, attention_masks), + weights_list, + ) ) else: tokens = tokenize_strategy.tokenize(captions) with torch.no_grad(): - hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens + ) ) if hidden_state.dtype != torch.float32: hidden_state = hidden_state.float() hidden_state = hidden_state.cpu().numpy() - attention_mask = attention_masks.cpu().numpy() - input_ids = tokens.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() # (B, S) + input_ids = input_ids.cpu().numpy() # (B, S) - - for i, info in enumerate(infos): + for i, info in enumerate(batch): hidden_state_i = hidden_state[i] attention_mask_i = attention_mask[i] input_ids_i = input_ids[i] + assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}" + if self.cache_to_disk: np.savez( info.text_encoder_outputs_npz, hidden_state=hidden_state_i, attention_mask=attention_mask_i, - input_ids=input_ids_i + input_ids=input_ids_i, ) else: - info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] + info.text_encoder_outputs = [ + hidden_state_i, + attention_mask_i, + input_ids_i, + ] class LuminaLatentsCachingStrategy(LatentsCachingStrategy): @@ -227,7 +295,14 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): npz_path: str, flip_aug: bool, alpha_mask: bool, - ): + ) -> bool: + """ + Args: + bucket_reso (Tuple[int, int]): The resolution of the bucket. + npz_path (str): Path to the npz file. + flip_aug (bool): Whether to flip the image. + alpha_mask (bool): Whether to apply + """ return self._default_is_disk_cached_latents_expected( 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True ) @@ -241,6 +316,20 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): Optional[np.ndarray], Optional[np.ndarray], ]: + """ + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet + """ return self._default_load_latents_from_disk( 8, npz_path, bucket_reso ) # support multi-resolution diff --git a/library/train_util.py b/library/train_util.py index 4eccc4a0..230b2c4b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -195,7 +195,7 @@ class ImageInfo: self.latents_flipped: Optional[torch.Tensor] = None self.latents_npz: Optional[str] = None # set in cache_latents self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = ( None # crop left top right bottom in original pixel size, not latents size ) self.cond_img_path: Optional[str] = None @@ -211,6 +211,8 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.system_prompt: Optional[str] = None + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -434,6 +436,7 @@ class BaseSubset: custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -464,6 +467,8 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split + self.system_prompt = system_prompt + class DreamBoothSubset(BaseSubset): def __init__( @@ -495,6 +500,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -522,6 +528,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.is_reg = is_reg @@ -564,6 +571,7 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -591,6 +599,7 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.metadata_file = metadata_file @@ -629,6 +638,7 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -656,6 +666,7 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.conditioning_data_dir = conditioning_data_dir @@ -1686,8 +1697,9 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: + system_prompt = subset.system_prompt or "" caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension # if self.XTI_layers: # caption_layer = [] # for layer in self.XTI_layers: @@ -2059,6 +2071,7 @@ class DreamBoothDataset(BaseDataset): num_train_images = 0 num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] + for subset in subsets: num_repeats = subset.num_repeats if self.is_training_dataset else 1 if num_repeats < 1: @@ -2086,7 +2099,7 @@ class DreamBoothDataset(BaseDataset): num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, subset.system_prompt or "" + caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: @@ -2967,7 +2980,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size diff --git a/lumina_train_network.py b/lumina_train_network.py index 81acfb51..adbf834c 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -1,17 +1,17 @@ import argparse import copy -import math -import random -from typing import Any, Optional, Union, Tuple +from typing import Any, Tuple import torch -from torch import Tensor -from accelerate import Accelerator from library.device_utils import clean_memory_on_device, init_ipex init_ipex() +from torch import Tensor +from accelerate import Accelerator + + import train_network from library import ( lumina_models, @@ -40,10 +40,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def assert_extra_args(self, args, train_dataset_group, val_dataset_group): super().assert_extra_args(args, train_dataset_group, val_dataset_group) - if ( - args.cache_text_encoder_outputs_to_disk - and not args.cache_text_encoder_outputs - ): + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning("Enabling cache_text_encoder_outputs due to disk caching") args.cache_text_encoder_outputs = True @@ -59,17 +56,14 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): model = lumina_util.load_lumina_model( args.pretrained_model_name_or_path, loading_dtype, - "cpu", + torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, ) if args.fp8_base: # check dtype of model - if ( - model.dtype == torch.float8_e4m3fnuz - or model.dtype == torch.float8_e5m2 - or model.dtype == torch.float8_e5m2fnuz - ): + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 Lumina 2 model") @@ -92,17 +86,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model def get_tokenize_strategy(self, args): - return strategy_lumina.LuminaTokenizeStrategy( - args.gemma2_max_token_length, args.tokenizer_cache_dir - ) + return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): return [tokenize_strategy.tokenizer] def get_latents_caching_strategy(self, args): - return strategy_lumina.LuminaLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, False - ) + return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) def get_text_encoding_strategy(self, args): return strategy_lumina.LuminaTextEncodingStrategy() @@ -144,15 +134,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to( - accelerator.device, dtype=weight_dtype - ) # always not fp8 + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 if text_encoders[0].dtype == torch.float8_e4m3fn: # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8( - 1, text_encoders[1], text_encoders[1].dtype, weight_dtype - ) + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) else: # otherwise, we need to convert it to target dtype text_encoders[0].to(weight_dtype) @@ -162,35 +148,36 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # cache sample prompts if args.sample_prompts is not None: - logger.info( - f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" - ) + logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}") tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) + system_prompt = args.system_prompt or "" sample_prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = ( - {} - ) # key: prompt, value: text encoder outputs + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in sample_prompts: - prompts = [prompt_dict.get("prompt", ""), - prompt_dict.get("negative_prompt", "")] - logger.info( - f"cache Text Encoder outputs for prompt: {prompts[0]}" - ) - tokens_and_masks = tokenize_strategy.tokenize(prompts) - sample_prompts_te_outputs[prompts[0]] = ( - text_encoding_strategy.encode_tokens( + prompts = [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ] + for prompt in prompts: + prompt = system_prompt + prompt + if prompt in sample_prompts_te_outputs: + continue + + logger.info(f"cache Text Encoder outputs for prompt: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens( tokenize_strategy, text_encoders, tokens_and_masks, ) - ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs accelerator.wait_for_everyone() @@ -235,12 +222,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # Remaining methods maintain similar structure to flux implementation # with Lumina-specific model calls and strategies - def get_noise_scheduler( - self, args: argparse.Namespace, device: torch.device - ) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler( - num_train_timesteps=1000, shift=args.discrete_flow_shift - ) + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler @@ -258,26 +241,45 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): noise_scheduler, latents, batch, - text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks) + text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks) dit: lumina_models.NextDiT, network, weight_dtype, train_unet, is_train=True, ): + assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler) noise = torch.randn_like(latents) bsz = latents.shape[0] - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = ( - flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype - ) + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = lumina_train_util.compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) - # May not need to pack/unpack? - # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 - # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `latents` + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -289,48 +291,35 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # Unpack Gemma2 outputs gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds - def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): + def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = dit( x=img, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features - cap_mask=gemma2_attn_mask.to( - dtype=torch.int32 - ), # Gemma2的attention mask + cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask ) return model_pred model_pred = call_dit( img=noisy_model_input, gemma2_hidden_states=gemma2_hidden_states, - timesteps=timesteps, gemma2_attn_mask=gemma2_attn_mask, + timesteps=timesteps, ) - # May not need to pack/unpack? - # unpack latents - # model_pred = lumina_util.unpack_latents( - # model_pred, packed_latent_height, packed_latent_width - # ) - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type( - args, model_pred, noisy_model_input, sigmas - ) + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - # flow matching loss: this is different from SD3 - target = noise - latents + # flow matching loss + target = latents - noise # differential output preservation if "custom_attributes" in batch: diff_output_pr_indices = [] for i, custom_attributes in enumerate(batch["custom_attributes"]): - if ( - "diff_output_preservation" in custom_attributes - and custom_attributes["diff_output_preservation"] - ): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: diff_output_pr_indices.append(i) if len(diff_output_pr_indices) > 0: @@ -338,9 +327,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): with torch.no_grad(): model_pred_prior = call_dit( img=noisy_model_input[diff_output_pr_indices], - gemma2_hidden_states=gemma2_hidden_states[ - diff_output_pr_indices - ], + gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices], timesteps=timesteps[diff_output_pr_indices], gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]), ) @@ -363,9 +350,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec( - None, args, False, True, False, lumina="lumina2" - ) + return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2") def update_metadata(self, metadata, args): metadata["ss_weighting_scheme"] = args.weighting_scheme @@ -384,12 +369,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): text_encoder.embed_tokens.requires_grad_(True) - def prepare_text_encoder_fp8( - self, index, text_encoder, te_weight_dtype, weight_dtype - ): - logger.info( - f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}" - ) + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") text_encoder.to(te_weight_dtype) # fp8 text_encoder.embed_tokens.to(dtype=weight_dtype) @@ -402,12 +383,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # if we doesn't swap blocks, we can move the model to device nextdit = unet assert isinstance(nextdit, lumina_models.NextDiT) - nextdit = accelerator.prepare( - nextdit, device_placement=[not self.is_swapping_blocks] - ) - accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( - accelerator.device - ) # reduce peak memory usage + nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() return nextdit diff --git a/train_network.py b/train_network.py index 674f1cb6..2cf11af7 100644 --- a/train_network.py +++ b/train_network.py @@ -129,7 +129,7 @@ class NetworkTrainer: if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) - def load_target_model(self, args, weight_dtype, accelerator): + def load_target_model(self, args, weight_dtype, accelerator) -> tuple: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む @@ -354,12 +354,13 @@ class NetworkTrainer: if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions']) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), From 6d7bec8a374c610d31986f049e2296974471f58c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 01:46:47 -0500 Subject: [PATCH 16/73] Remove non-used code --- library/lumina_train_util.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 414b2849..db9af238 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1,5 +1,4 @@ import inspect -import enum import argparse import math import os @@ -9,20 +8,16 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union import torch from torch import Tensor -from torchdiffeq import odeint from accelerate import Accelerator, PartialState from transformers import Gemma2Model from tqdm import tqdm from PIL import Image from safetensors.torch import save_file -from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler -from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util +from library import lumina_models, strategy_base, strategy_lumina, train_util from library.flux_models import AutoEncoder from library.device_utils import init_ipex, clean_memory_on_device from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler -from library.lumina_dpm_solver import NoiseScheduleFlow, DPM_Solver -import library.lumina_path as path init_ipex() From 42a801514ccad054ac7c362ff5a9c0aa0e1e79d7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 13:48:37 -0500 Subject: [PATCH 17/73] Fix system prompt in datasets --- library/lumina_train_util.py | 2 +- library/train_util.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index db9af238..487ae2f9 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -280,7 +280,7 @@ def sample_image_inference( generator=generator, ) - scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0, use_karras_sigmas=True) + scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0) timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) # if controlnet_image is not None: diff --git a/library/train_util.py b/library/train_util.py index 230b2c4b..ded23f41 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1869,6 +1869,7 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + system_prompt: Optional[str], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1881,6 +1882,7 @@ class DreamBoothDataset(BaseDataset): self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split + self.system_prompt = system_prompt self.enable_bucket = enable_bucket if self.enable_bucket: @@ -2098,8 +2100,9 @@ class DreamBoothDataset(BaseDataset): else: num_train_images += num_repeats * len(img_paths) + system_prompt = self.system_prompt or subset.system_prompt or "" for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, subset.system_prompt or "" + caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: From ba725a84e9511abeb3a31b4bc45cd7eba4c12d65 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 18:01:09 -0500 Subject: [PATCH 18/73] Set default discrete_flow_shift to 6.0. Remove default system prompt. --- library/lumina_models.py | 125 +++++++++++++++++++---------------- library/lumina_train_util.py | 6 +- 2 files changed, 70 insertions(+), 61 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index f819b68f..365453c1 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -1,9 +1,19 @@ +# Copyright Alpha VLLM/Lumina Image 2.0 and contributors # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # References: # GLIDE: https://github.com/openai/glide-text2im # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py @@ -13,8 +23,6 @@ import math from typing import List, Optional, Tuple from dataclasses import dataclass -from einops import rearrange - import torch from torch import Tensor from torch.utils.checkpoint import checkpoint @@ -25,6 +33,7 @@ try: from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa except: + # flash_attn may not be available but it is not required pass try: @@ -34,6 +43,58 @@ except: warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + ############################################################################# + # RMSNorm # + ############################################################################# + + class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x) -> Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor): + """ + Apply RMSNorm to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + """ + x_dtype = x.dtype + # To handle float8 we need to convert the tensor to float + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) + + @dataclass class LuminaParams: @@ -111,58 +172,6 @@ class GradientCheckpointMixin(nn.Module): return self._forward(*args, **kwargs) -############################################################################# -# RMSNorm # -############################################################################# - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x) -> Tensor: - """ - Apply the RMSNorm normalization to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - - """ - return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor): - """ - Apply RMSNorm to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - """ - x_dtype = x.dtype - # To handle float8 we need to convert the tensor to float - x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) - def modulate(x, scale): return x * (1 + scale.unsqueeze(1)) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 487ae2f9..172d09ea 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -878,8 +878,8 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--discrete_flow_shift", type=float, - default=3.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + default=6.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。", ) parser.add_argument( "--use_flash_attn", @@ -889,6 +889,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--system_prompt", type=str, - default="You are an assistant designed to generate high-quality images based on user prompts. ", + default="", help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。", ) From 48e7da2d4a844d60a4db1ac03b9a4a34a2c57720 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 20:19:24 -0500 Subject: [PATCH 19/73] Add sample batch size for Lumina --- library/lumina_models.py | 6 +- library/lumina_train_util.py | 298 +++++++++++++++++++++++------------ train_network.py | 3 + 3 files changed, 201 insertions(+), 106 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 365453c1..d86a9cb2 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -880,8 +880,8 @@ class NextDiT(nn.Module): self.n_heads = n_heads self.gradient_checkpointing = False - self.cpu_offload_checkpointing = False - self.blocks_to_swap = None + self.cpu_offload_checkpointing = False # TODO: not yet supported + self.blocks_to_swap = None # TODO: not yet supported @property def device(self): @@ -982,8 +982,8 @@ class NextDiT(nn.Module): l_effective_cap_len = cap_mask.sum(dim=1).tolist() encoder_seq_len = cap_mask.shape[1] - image_seq_len = (height // self.patch_size) * (width // self.patch_size) + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] max_seq_len = max(seq_lengths) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 172d09ea..4aa48e8b 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -4,7 +4,7 @@ import math import os import numpy as np import time -from typing import Callable, Dict, List, Optional, Tuple, Any, Union +from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator import torch from torch import Tensor @@ -32,6 +32,59 @@ logger = logging.getLogger(__name__) # region sample images +def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]: + """ + Group prompt dictionaries into batches with configurable batch size. + + Args: + prompt_dicts (list): List of dictionaries containing prompt parameters. + batch_size (int, optional): Number of prompts per batch. Defaults to None. + + Yields: + list[dict[str, str]]: Batch of prompts. + """ + # Validate batch_size + if batch_size is not None: + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size must be a positive integer or None") + + # Group prompts by their parameters + batches = {} + for prompt_dict in prompt_dicts: + # Extract parameters + width = int(prompt_dict.get("width", 1024)) + height = int(prompt_dict.get("height", 1024)) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + guidance_scale = float(prompt_dict.get("scale", 3.5)) + sample_steps = int(prompt_dict.get("sample_steps", 38)) + seed = prompt_dict.get("seed", None) + seed = int(seed) if seed is not None else None + + # Create a key based on the parameters + key = (width, height, guidance_scale, seed, sample_steps) + + # Add the prompt_dict to the corresponding batch + if key not in batches: + batches[key] = [] + batches[key].append(prompt_dict) + + # Yield each batch with its parameters + for key in batches: + prompts = batches[key] + if batch_size is None: + # Yield the entire group as a single batch + yield prompts + else: + # Split the group into batches of size `batch_size` + start = 0 + while start < len(prompts): + end = start + batch_size + batch = prompts[start:end] + yield batch + start = end + + @torch.no_grad() def sample_images( accelerator: Accelerator, @@ -39,9 +92,9 @@ def sample_images( epoch: int, global_step: int, nextdit: lumina_models.NextDiT, - vae: torch.nn.Module, + vae: AutoEncoder, gemma2_model: Gemma2Model, - sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]], + sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], prompt_replacement: Optional[Tuple[str, str]] = None, controlnet=None, ): @@ -54,11 +107,13 @@ def sample_images( epoch (int): Current epoch number. global_step (int): Current global step number. nextdit (lumina_models.NextDiT): The NextDiT model instance. - vae (torch.nn.Module): The VAE module. + vae (AutoEncoder): The VAE module. gemma2_model (Gemma2Model): The Gemma2 model instance. - sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample. - prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None. - controlnet:: ControlNet model + sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): + Dictionary ist of tuples containing the encoded prompts, text masks, and timestep for each sample. + prompt_replacement (Optional[Tuple[str, str]], optional): + Tuple containing the prompt and negative prompt replacements. Defaults to None. + controlnet (): ControlNet model, not yet supported Returns: None @@ -110,9 +165,12 @@ def sample_images( except Exception: pass + batch_size = args.sample_batch_size or args.train_batch_size or 1 + if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - for prompt_dict in prompts: + # TODO: batch prompts together with buckets of image sizes + for prompt_dicts in batchify(prompts, batch_size): sample_image_inference( accelerator, args, @@ -120,7 +178,7 @@ def sample_images( gemma2_model, vae, save_dir, - prompt_dict, + prompt_dicts, epoch, global_step, sample_prompts_gemma2_outputs, @@ -135,7 +193,8 @@ def sample_images( per_process_prompts.append(prompts[i :: distributed_state.num_processes]) with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: + # TODO: batch prompts together with buckets of image sizes + for prompt_dicts in batchify(prompt_dict_lists[0], batch_size): sample_image_inference( accelerator, args, @@ -143,7 +202,7 @@ def sample_images( gemma2_model, vae, save_dir, - prompt_dict, + prompt_dicts, epoch, global_step, sample_prompts_gemma2_outputs, @@ -166,10 +225,10 @@ def sample_image_inference( gemma2_model: Gemma2Model, vae: AutoEncoder, save_dir: str, - prompt_dict: Dict[str, str], + prompt_dicts: list[Dict[str, str]], epoch: int, global_step: int, - sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]], + sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], prompt_replacement: Optional[Tuple[str, str]] = None, controlnet=None, ): @@ -192,43 +251,6 @@ def sample_image_inference( Returns: None """ - assert isinstance(prompt_dict, dict) - # negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = int(prompt_dict.get("sample_steps", 38)) - width = int(prompt_dict.get("width", 1024)) - height = int(prompt_dict.get("height", 1024)) - guidance_scale = float(prompt_dict.get("scale", 3.5)) - seed = prompt_dict.get("seed", None) - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - negative_prompt: str = prompt_dict.get("negative_prompt", "") - # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - seed = int(seed) if seed is not None else None - assert seed is None or seed > 0, f"Invalid seed {seed}" - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - generator = torch.Generator(device=accelerator.device) - if seed is not None: - generator.manual_seed(seed) - - # if negative_prompt is None: - # negative_prompt = "" - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - logger.info(f"prompt: {prompt}") - logger.info(f"negative_prompt: {negative_prompt}") - logger.info(f"height: {height}") - logger.info(f"width: {width}") - logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {guidance_scale}") - # logger.info(f"sample_sampler: {sampler_name}") - if seed is not None: - logger.info(f"seed: {seed}") # encode prompts tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() @@ -237,33 +259,86 @@ def sample_image_inference( assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - system_prompt = args.system_prompt or "" + text_conds = [] - # Apply system prompt to prompts - prompt = system_prompt + prompt - negative_prompt = system_prompt + negative_prompt + # assuming seed, width, height, sample steps, guidance are the same + width = int(prompt_dicts[0].get("width", 1024)) + height = int(prompt_dicts[0].get("height", 1024)) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 - # Get sample prompts from cache - if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: - gemma2_conds = sample_prompts_gemma2_outputs[prompt] - logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") + guidance_scale = float(prompt_dicts[0].get("scale", 3.5)) + sample_steps = int(prompt_dicts[0].get("sample_steps", 36)) + seed = prompt_dicts[0].get("seed", None) + seed = int(seed) if seed is not None else None + assert seed is None or seed > 0, f"Invalid seed {seed}" + generator = torch.Generator(device=accelerator.device) + if seed is not None: + generator.manual_seed(seed) - if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: - neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] - logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + for prompt_dict in prompt_dicts: + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + negative_prompt = prompt_dict.get("negative_prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - # Load sample prompts from Gemma 2 - if gemma2_model is not None: - logger.info(f"Encoding prompt with Gemma2: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) - gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) - neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + if negative_prompt is None: + negative_prompt = "" + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {guidance_scale}") + # logger.info(f"sample_sampler: {sampler_name}") - # Unpack Gemma2 outputs - gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds - neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds + system_prompt = args.system_prompt or "" + + # Apply system prompt to prompts + prompt = system_prompt + prompt + negative_prompt = system_prompt + negative_prompt + + # Get sample prompts from cache + if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: + gemma2_conds = sample_prompts_gemma2_outputs[prompt] + logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") + + if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: + neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] + logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + + # Load sample prompts from Gemma 2 + if gemma2_model is not None: + logger.info(f"Encoding prompt with Gemma2: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + + # Unpack Gemma2 outputs + gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds + neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds + + text_conds.append( + ( + gemma2_hidden_states.squeeze(0), + gemma2_attn_mask.squeeze(0), + neg_gemma2_hidden_states.squeeze(0), + neg_gemma2_attn_mask.squeeze(0), + ) + ) + + # Stack conditioning + cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device) + cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device) + uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device) + uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device) # sample image weight_dtype = vae.dtype # TOFO give dtype as argument @@ -279,6 +354,7 @@ def sample_image_inference( dtype=weight_dtype, generator=generator, ) + noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1) scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0) timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) @@ -294,10 +370,10 @@ def sample_image_inference( scheduler, nextdit, noise, - gemma2_hidden_states, - gemma2_attn_mask.to(accelerator.device), - neg_gemma2_hidden_states, - neg_gemma2_attn_mask.to(accelerator.device), + cond_hidden_states, + cond_attn_masks, + uncond_hidden_states, + uncond_attn_masks, timesteps=timesteps, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, @@ -307,34 +383,44 @@ def sample_image_inference( clean_memory_on_device(accelerator.device) org_vae_device = vae.device # will be on cpu vae.to(accelerator.device) # distributed_state.device is same as accelerator.device - with accelerator.autocast(): - x = vae.decode((x / vae.scale_factor) + vae.shift_factor) + for img, prompt_dict in zip(x, prompt_dicts): + + img = (img / vae.scale_factor) + vae.shift_factor + + with accelerator.autocast(): + # Add a single batch image for the VAE to decode + img = vae.decode(img.unsqueeze(0)) + + img = img.clamp(-1, 1) + img = img.permute(0, 2, 3, 1) # B, H, W, C + # Scale images back to 0 to 255 + img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8) + + # Get single image + image = Image.fromarray(img[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = int(prompt_dict.get("enum", 0)) + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + vae.to(org_vae_device) clean_memory_on_device(accelerator.device) - x = x.clamp(-1, 1) - x = x.permute(0, 2, 3, 1) - image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) - - # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list - # but adding 'enum' to the filename should be enough - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - i: int = int(prompt_dict.get("enum", 0)) - img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" - image.save(os.path.join(save_dir, img_filename)) - - # send images to wandb if enabled - if "wandb" in [tracker.name for tracker in accelerator.trackers]: - wandb_tracker = accelerator.get_tracker("wandb") - - import wandb - - # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption - def time_shift(mu: float, sigma: float, t: torch.Tensor): # the following implementation was original for t=0: clean / t=1: noise @@ -879,16 +965,22 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): "--discrete_flow_shift", type=float, default=6.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。", + help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0", ) parser.add_argument( "--use_flash_attn", action="store_true", - help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。", + help="Use Flash Attention for the model / モデルにFlash Attentionを使用する", ) parser.add_argument( "--system_prompt", type=str, default="", - help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。", + help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=None, + help="Batch size to use for sampling, defaults to --training_batch_size value. Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --training_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます", ) diff --git a/train_network.py b/train_network.py index 2cf11af7..07de30b3 100644 --- a/train_network.py +++ b/train_network.py @@ -1242,6 +1242,7 @@ class NetworkTrainer: # For --sample_at_first optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + progress_bar.unpause() # Reset progress bar to before sampling images optimizer_train_fn() is_tracking = len(accelerator.trackers) > 0 if is_tracking: @@ -1344,6 +1345,7 @@ class NetworkTrainer: self.sample_images( accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet ) + progress_bar.unpause() # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1531,6 +1533,7 @@ class NetworkTrainer: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + progress_bar.unpause() optimizer_train_fn() # end of epoch From 2c94d17f0554d1f468e1249e24ad8db0ca812f19 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 20:21:06 -0500 Subject: [PATCH 20/73] Fix typo --- library/lumina_train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 4aa48e8b..87f7ba36 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -110,7 +110,7 @@ def sample_images( vae (AutoEncoder): The VAE module. gemma2_model (Gemma2Model): The Gemma2 model instance. sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): - Dictionary ist of tuples containing the encoded prompts, text masks, and timestep for each sample. + Dictionary of tuples containing the encoded prompts, text masks, and timestep for each sample. prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None. controlnet (): ControlNet model, not yet supported From fc772affbe4345c8e0d14eb53ebc883f8c5a576f Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 24 Feb 2025 14:10:24 +0800 Subject: [PATCH 21/73] =?UTF-8?q?1=E3=80=81Implement=20cfg=5Ftrunc=20calcu?= =?UTF-8?q?lation=20directly=20using=20timesteps,=20without=20intermediate?= =?UTF-8?q?=20steps.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 2、Deprecate and remove the guidance_scale parameter because it used in inference not train 3、Add inference command-line arguments --ct for cfg_trunc_ratio and --rc for renorm_cfg to control CFG truncation and renormalization during inference. --- library/lumina_models.py | 2 +- library/lumina_train_util.py | 46 +++++++++++++++++------------------- library/train_util.py | 10 ++++++++ lumina_train_network.py | 1 - 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index d86a9cb2..1a441a69 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -1081,7 +1081,7 @@ class NextDiT(nn.Module): cap_feats: Tensor, cap_mask: Tensor, cfg_scale: float, - cfg_trunc: int = 100, + cfg_trunc: float = 0.25, renorm_cfg: float = 1.0, ): """ diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 87f7ba36..f54b202d 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -58,11 +58,13 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N width = max(64, width - width % 8) # round to divisible by 8 guidance_scale = float(prompt_dict.get("scale", 3.5)) sample_steps = int(prompt_dict.get("sample_steps", 38)) + cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0)) seed = prompt_dict.get("seed", None) seed = int(seed) if seed is not None else None # Create a key based on the parameters - key = (width, height, guidance_scale, seed, sample_steps) + key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg) # Add the prompt_dict to the corresponding batch if key not in batches: @@ -268,6 +270,8 @@ def sample_image_inference( width = max(64, width - width % 8) # round to divisible by 8 guidance_scale = float(prompt_dicts[0].get("scale", 3.5)) + cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0)) sample_steps = int(prompt_dicts[0].get("sample_steps", 36)) seed = prompt_dicts[0].get("seed", None) seed = int(seed) if seed is not None else None @@ -295,6 +299,8 @@ def sample_image_inference( logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") logger.info(f"scale: {guidance_scale}") + logger.info(f"trunc: {cfg_trunc_ratio}") + logger.info(f"renorm: {renorm_cfg}") # logger.info(f"sample_sampler: {sampler_name}") system_prompt = args.system_prompt or "" @@ -375,8 +381,9 @@ def sample_image_inference( uncond_hidden_states, uncond_attn_masks, timesteps=timesteps, - num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, + cfg_trunc_ratio=cfg_trunc_ratio, + renorm_cfg=renorm_cfg, ) # Latent to image @@ -550,10 +557,9 @@ def denoise( neg_txt: Tensor, neg_txt_mask: Tensor, timesteps: Union[List[float], torch.Tensor], - num_inference_steps: int = 38, guidance_scale: float = 4.0, - cfg_trunc_ratio: float = 1.0, - cfg_normalization: bool = True, + cfg_trunc_ratio: float = 0.25, + renorm_cfg: float = 1.0, ): """ Denoise an image using the NextDiT model. @@ -578,21 +584,17 @@ def denoise( The guidance scale for the denoising process. Defaults to 4.0. cfg_trunc_ratio (float, optional): The ratio of the timestep interval to apply normalization-based guidance scale. - cfg_normalization (bool, optional): - Whether to apply normalization-based guidance scale. - + renorm_cfg (float, optional): + The factor to limit the maximum norm after guidance. Default: 1.0 Returns: img (Tensor): Denoised latent tensor """ for i, t in enumerate(tqdm(timesteps)): - # compute whether apply classifier-free truncation on this timestep - do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio - # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(img.shape[0]).to(model.device) + current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device) noise_pred_cond = model( img, @@ -601,7 +603,8 @@ def denoise( cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask ) - if not do_classifier_free_truncation: + # compute whether to apply classifier-free guidance based on current timestep + if current_timestep[0] < cfg_trunc_ratio: noise_pred_uncond = model( img, current_timestep, @@ -610,10 +613,12 @@ def denoise( ) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # apply normalization after classifier-free guidance - if cfg_normalization: - cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) - noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_pred = noise_pred * (cond_norm / noise_norm) + if float(renorm_cfg) > 0.0: + cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True) + max_new_norm = cond_norm * float(renorm_cfg) + noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True) + if noise_norm >= max_new_norm: + noise_pred = noise_pred * (max_new_norm / noise_norm) else: noise_pred = noise_pred_cond @@ -932,13 +937,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", ) - parser.add_argument( - "--guidance_scale", - type=float, - default=3.5, - help="the NextDIT.1 dev variant is a guidance distilled model", - ) - parser.add_argument( "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], diff --git a/library/train_util.py b/library/train_util.py index ded23f41..18aceaf7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6188,6 +6188,16 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["controlnet_image"] = m.group(1) continue + m = re.match(r"ct (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["cfg_trunc_ratio"] = float(m.group(1)) + continue + + m = re.match(r"rc (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["renorm_cfg"] = float(m.group(1)) + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) diff --git a/lumina_train_network.py b/lumina_train_network.py index adbf834c..0fd4da6b 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -357,7 +357,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): metadata["ss_logit_mean"] = args.logit_mean metadata["ss_logit_std"] = args.logit_std metadata["ss_mode_scale"] = args.mode_scale - metadata["ss_guidance_scale"] = args.guidance_scale metadata["ss_timestep_sampling"] = args.timestep_sampling metadata["ss_sigmoid_scale"] = args.sigmoid_scale metadata["ss_model_prediction_type"] = args.model_prediction_type From 5f9047c8cf4f28019d1365cbc7e439f5afbdda0a Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 26 Feb 2025 01:00:35 +0800 Subject: [PATCH 22/73] add truncation when > max_length --- library/lumina_train_util.py | 1 - library/strategy_lumina.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index f54b202d..20df7eef 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -320,7 +320,6 @@ def sample_image_inference( # Load sample prompts from Gemma 2 if gemma2_model is not None: - logger.info(f"Encoding prompt with Gemma2: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 5d6e100f..c9e65423 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -54,6 +54,7 @@ class LuminaTokenizeStrategy(TokenizeStrategy): max_length=self.max_length, return_tensors="pt", padding="max_length", + truncation=True, pad_to_multiple_of=8, ) return (encodings.input_ids, encodings.attention_mask) From ce37c08b9a3b8e6567c70712f9d6899a304e98b6 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 26 Feb 2025 11:20:03 +0800 Subject: [PATCH 23/73] clean code and add finetune code --- library/lumina_train_util.py | 212 ++++++-- lumina_train.py | 953 +++++++++++++++++++++++++++++++++++ lumina_train_network.py | 37 +- 3 files changed, 1118 insertions(+), 84 deletions(-) create mode 100644 lumina_train.py diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 20df7eef..ca039167 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -32,7 +32,9 @@ logger = logging.getLogger(__name__) # region sample images -def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]: +def batchify( + prompt_dicts, batch_size=None +) -> Generator[list[dict[str, str]], None, None]: """ Group prompt dictionaries into batches with configurable batch size. @@ -64,7 +66,15 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N seed = int(seed) if seed is not None else None # Create a key based on the parameters - key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg) + key = ( + width, + height, + guidance_scale, + seed, + sample_steps, + cfg_trunc_ratio, + renorm_cfg, + ) # Add the prompt_dict to the corresponding batch if key not in batches: @@ -131,7 +141,9 @@ def sample_images( if epoch is None or epoch % args.sample_every_n_epochs != 0: return else: - if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if ( + global_step % args.sample_every_n_steps != 0 or epoch is not None + ): # steps is not divisible or end of epoch return assert ( @@ -139,12 +151,21 @@ def sample_images( ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください" logger.info("") - logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}") - if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: - logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.info( + f"generating sample images at step / サンプル画像生成 ステップ: {global_step}" + ) + if ( + not os.path.isfile(args.sample_prompts) + and sample_prompts_gemma2_outputs is None + ): + logger.error( + f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}" + ) return - distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + distributed_state = ( + PartialState() + ) # for multi gpu distributed inference. this is a singleton, so it's safe to use it here # unwrap nextdit and gemma2_model nextdit = accelerator.unwrap_model(nextdit) @@ -163,7 +184,9 @@ def sample_images( rng_state = torch.get_rng_state() cuda_rng_state = None try: - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + cuda_rng_state = ( + torch.cuda.get_rng_state() if torch.cuda.is_available() else None + ) except Exception: pass @@ -194,7 +217,9 @@ def sample_images( for i in range(distributed_state.num_processes): per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + with distributed_state.split_between_processes( + per_process_prompts + ) as prompt_dict_lists: # TODO: batch prompts together with buckets of image sizes for prompt_dicts in batchify(prompt_dict_lists[0], batch_size): sample_image_inference( @@ -289,7 +314,9 @@ def sample_image_inference( if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + negative_prompt = negative_prompt.replace( + prompt_replacement[0], prompt_replacement[1] + ) if negative_prompt is None: negative_prompt = "" @@ -314,17 +341,26 @@ def sample_image_inference( gemma2_conds = sample_prompts_gemma2_outputs[prompt] logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") - if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: + if ( + sample_prompts_gemma2_outputs + and negative_prompt in sample_prompts_gemma2_outputs + ): neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] - logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + logger.info( + f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}" + ) # Load sample prompts from Gemma 2 if gemma2_model is not None: tokens_and_masks = tokenize_strategy.tokenize(prompt) - gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + gemma2_conds = encoding_strategy.encode_tokens( + tokenize_strategy, [gemma2_model], tokens_and_masks + ) tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) - neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + neg_gemma2_conds = encoding_strategy.encode_tokens( + tokenize_strategy, [gemma2_model], tokens_and_masks + ) # Unpack Gemma2 outputs gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds @@ -340,10 +376,18 @@ def sample_image_inference( ) # Stack conditioning - cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device) - cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device) - uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device) - uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device) + cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to( + accelerator.device + ) + cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to( + accelerator.device + ) + uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to( + accelerator.device + ) + uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to( + accelerator.device + ) # sample image weight_dtype = vae.dtype # TOFO give dtype as argument @@ -362,7 +406,9 @@ def sample_image_inference( noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1) scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0) - timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, num_inference_steps=sample_steps + ) # if controlnet_image is not None: # controlnet_image = Image.open(controlnet_image).convert("RGB") @@ -422,7 +468,9 @@ def sample_image_inference( import wandb # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + wandb_tracker.log( + {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False + ) # positive prompt as a caption vae.to(org_vae_device) clean_memory_on_device(accelerator.device) @@ -437,7 +485,9 @@ def time_shift(mu: float, sigma: float, t: torch.Tensor): return t -def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]: +def get_lin_function( + x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15 +) -> Callable[[float], float]: """ Get linear function @@ -481,7 +531,9 @@ def get_schedule( # shifting the schedule to favor high timesteps for higher signal images if shift: # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len) + mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)( + image_seq_len + ) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() @@ -520,9 +572,13 @@ def retrieve_timesteps( second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -532,7 +588,9 @@ def retrieve_timesteps( timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -593,7 +651,9 @@ def denoise( # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device) + current_timestep = current_timestep * torch.ones( + img.shape[0], device=img.device + ) noise_pred_cond = model( img, @@ -610,12 +670,20 @@ def denoise( cap_feats=neg_txt, # Gemma2的hidden states作为caption features cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask ) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) # apply normalization after classifier-free guidance if float(renorm_cfg) > 0.0: - cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True) + cond_norm = torch.linalg.vector_norm( + noise_pred_cond, + dim=tuple(range(1, len(noise_pred_cond.shape))), + keepdim=True, + ) max_new_norm = cond_norm * float(renorm_cfg) - noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True) + noise_norm = torch.linalg.vector_norm( + noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True + ) if noise_norm >= max_new_norm: noise_pred = noise_pred * (max_new_norm / noise_norm) else: @@ -640,7 +708,11 @@ def denoise( # region train def get_sigmas( - noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32 + noise_scheduler: FlowMatchEulerDiscreteScheduler, + timesteps: Tensor, + device: torch.device, + n_dim=4, + dtype=torch.float32, ) -> Tensor: """ Get sigmas for timesteps @@ -667,7 +739,11 @@ def get_sigmas( def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, ): """ Compute the density for sampling the timesteps when doing SD3 training. @@ -688,7 +764,9 @@ def compute_density_for_timestep_sampling( """ if weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.normal( + mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu" + ) u = torch.nn.functional.sigmoid(u) elif weighting_scheme == "mode": u = torch.rand(size=(batch_size,), device="cpu") @@ -722,7 +800,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor return weighting -def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]: +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[Tensor, Tensor, Tensor]: """ Get noisy model input and timesteps. @@ -753,27 +833,27 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * noise + t * latents elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + logits_norm = ( + logits_norm * args.sigmoid_scale + ) # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * noise + t * latents elif args.timestep_sampling == "nextdit_shift": - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - timesteps = time_shift(mu, 1.0, timesteps) + t = torch.rand((bsz,), device=device) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16)) # lumina use //16 + t = time_shift(mu, 1.0, t) - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * noise + t * latents else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -788,8 +868,10 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d timesteps = noise_scheduler.timesteps[indices].to(device=device) # Add noise according to flow matching. - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + sigmas = get_sigmas( + noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype + ) + noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas @@ -821,7 +903,9 @@ def apply_model_prediction_type( # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) return model_pred, weighting @@ -863,15 +947,27 @@ def save_models( def save_lumina_model_on_train_end( - args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + lumina: lumina_models.NextDiT, ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec( - None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2" + None, + args, + False, + False, + False, + is_stable_diffusion_ckpt=True, + lumina="lumina2", ) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) - train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + train_util.save_sd_model_on_train_end_common( + args, True, True, epoch, global_step, sd_saver, None + ) # epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている @@ -901,7 +997,15 @@ def save_lumina_model_on_epoch_end_or_stepwise( """ def sd_saver(ckpt_file: str, epoch_no: int, global_step: int): - sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + sai_metadata = train_util.get_sai_model_spec( + {}, + args, + False, + False, + False, + is_stable_diffusion_ckpt=True, + lumina="lumina2", + ) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_epoch_end_or_stepwise_common( @@ -927,7 +1031,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): type=str, help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提", ) - parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--ae", + type=str, + help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)", + ) parser.add_argument( "--gemma2_max_token_length", type=int, diff --git a/lumina_train.py b/lumina_train.py new file mode 100644 index 00000000..330d0093 --- /dev/null +++ b/lumina_train.py @@ -0,0 +1,953 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import ( + deepspeed_utils, + lumina_train_util, + lumina_util, + strategy_base, + strategy_lumina, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + # assert ( + # args.blocks_to_swap is None or args.blocks_to_swap == 0 + # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator( + ConfigSanitizer(True, True, args.masked_loss, True) + ) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = ( + config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + ) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = ( + train_dataset_group if args.max_data_loader_n_workers == 0 else None + ) + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + False, + ) + ) + strategy_base.TokenizeStrategy.set_strategy( + strategy_lumina.LuminaTokenizeStrategy() + ) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = lumina_util.load_ae( + args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.gemma2_max_token_length is None: + gemma2_max_token_length = 256 + else: + gemma2_max_token_length = args.gemma2_max_token_length + + lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy( + gemma2_max_token_length + ) + strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy) + + # load gemma2 for caching text encoder outputs + gemma2 = lumina_util.load_gemma2( + args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + gemma2.eval() + gemma2.requires_grad_(False) + + text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + gemma2.to(accelerator.device) + + text_encoder_caching_strategy = ( + strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + text_encoder_caching_strategy + ) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info( + f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" + ) + + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( + strategy_base.TextEncodingStrategy.get_strategy() + ) + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = lumina_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = ( + text_encoding_strategy.encode_tokens( + lumina_tokenize_strategy, + [gemma2], + tokens_and_masks, + ) + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + gemma2 = None + clean_memory_on_device(accelerator.device) + + # load lumina + nextdit = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + loading_dtype, + torch.device("cpu"), + disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, + ) + + if args.gradient_checkpointing: + nextdit.enable_gradient_checkpointing( + cpu_offload=args.cpu_offload_checkpointing + ) + + nextdit.requires_grad_(True) + + # block swap + + # backward compatibility + # if args.blocks_to_swap is None: + # blocks_to_swap = args.double_blocks_to_swap or 0 + # if args.single_blocks_to_swap is not None: + # blocks_to_swap += args.single_blocks_to_swap // 2 + # if blocks_to_swap > 0: + # logger.warning( + # "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + # " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + # ) + # logger.info( + # f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + # ) + # args.blocks_to_swap = blocks_to_swap + # del blocks_to_swap + + # is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + # if is_swapping_blocks: + # # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # # This idea is based on 2kpr's great work. Thank you! + # logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + # flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(nextdit) + name_and_params = list(nextdit.named_parameters()) + # single param group for now + params_to_optimize.append( + {"params": [p for _, p in name_and_params], "lr": args.learning_rate} + ) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(nextdit.named_parameters()) + assert len(named_parameters) == len( + group["params"] + ), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info( + f"using {len(optimizers)} optimizers for blockwise fused optimizers" + ) + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError( + "Schedule-free optimizer is not supported with blockwise fused optimizers" + ) + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer( + args, trainable_params=params_to_optimize + ) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn( + optimizer, args + ) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min( + args.max_data_loader_n_workers, os.cpu_count() + ) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) + / accelerator.num_processes + / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [ + train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + for optimizer in optimizers + ] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + nextdit.to(weight_dtype) + if gemma2 is not None: + gemma2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + nextdit.to(weight_dtype) + if gemma2 is not None: + gemma2.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + gemma2.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + nextdit = accelerator.prepare( + nextdit, device_placement=[not is_swapping_blocks] + ) + if is_swapping_blocks: + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( + accelerator.device + ) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook( + create_grad_hook(param_name, param_group) + ) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_( + parameter, args.max_grad_norm + ) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = ( + math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + ) + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print( + f" num examples / サンプル数: {train_dataset_group.num_train_images}" + ) + accelerator.print( + f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}" + ) + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print( + f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" + ) + accelerator.print( + f" total optimization steps / 学習ステップ数: {args.max_train_steps}" + ) + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=args.discrete_flow_shift + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + lumina_train_util.sample_images( + accelerator, + args, + 0, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = { + i: 0 for i in range(len(optimizers)) + } # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to( + accelerator.device, dtype=weight_dtype + ) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ + ids.to(accelerator.device) + for ids in batch["input_ids_list"] + ] + text_encoder_conds = text_encoding_strategy.encode_tokens( + lumina_tokenize_strategy, + [gemma2], + input_ids, + ) + if args.full_fp16: + text_encoder_conds = [ + c.to(weight_dtype) for c in text_encoder_conds + ] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = ( + lumina_train_util.get_noisy_model_input_and_timesteps( + args, + noise_scheduler_copy, + latents, + noise, + accelerator.device, + weight_dtype, + ) + ) + # call model + gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = nextdit( + x=img, # image latents (B, C, H, W) + t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features + cap_mask=gemma2_attn_mask.to( + dtype=torch.int32 + ), # Gemma2的attention mask + ) + # apply model prediction type + model_pred, weighting = lumina_train_util.apply_model_prediction_type( + args, model_pred, noisy_model_input, sigmas + ) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed( + args, timesteps, noise_scheduler + ) + loss = train_util.conditional_loss( + model_pred.float(), target.float(), args.loss_type, "none", huber_c + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ( + "alpha_masks" in batch and batch["alpha_masks"] is not None + ): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + lumina_train_util.sample_images( + accelerator, + args, + None, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + + # 指定ステップごとにモデルを保存 + if ( + args.save_every_n_steps is not None + and global_step % args.save_every_n_steps == 0 + ): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(nextdit), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs( + logs, lr_scheduler, args.optimizer_type, including_unet=True + ) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(nextdit), + ) + + lumina_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + nextdit = accelerator.unwrap_model(nextdit) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + lumina_train_util.save_lumina_model_on_train_end( + args, save_dtype, epoch, global_step, nextdit + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + lumina_train_util.add_lumina_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/lumina_train_network.py b/lumina_train_network.py index 0fd4da6b..5f20c014 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -15,7 +15,6 @@ from accelerate import Accelerator import train_network from library import ( lumina_models, - flux_train_utils, lumina_util, lumina_train_util, sd3_train_utils, @@ -250,36 +249,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ): assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler) noise = torch.randn_like(latents) - bsz = latents.shape[0] - - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = lumina_train_util.compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - # Add noise according to flow matching. - # zt = (1 - texp) * x + texp * z1 - # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `latents` - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) - noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -310,7 +283,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) # flow matching loss target = latents - noise @@ -336,7 +309,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # model_pred_prior = lumina_util.unpack_latents( # model_pred_prior, packed_latent_height, packed_latent_width # ) - model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + model_pred_prior, _ = lumina_train_util.apply_model_prediction_type( args, model_pred_prior, noisy_model_input[diff_output_pr_indices], From a1a5627b13d0ebf182710ea0cea5e97ab2f6d580 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 26 Feb 2025 11:35:38 +0800 Subject: [PATCH 24/73] fix shift --- library/lumina_train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index ca039167..11dd3feb 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -848,7 +848,7 @@ def get_noisy_model_input_and_timesteps( noisy_model_input = (1 - t) * noise + t * latents elif args.timestep_sampling == "nextdit_shift": t = torch.rand((bsz,), device=device) - mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16)) # lumina use //16 + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) t = time_shift(mu, 1.0, t) timesteps = t * 1000.0 From 70403f6977471e543f4ffa1b82edc0b0a4d77a3b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 26 Feb 2025 23:33:50 -0500 Subject: [PATCH 25/73] fix cache text encoder outputs if not using disk. small cleanup/alignment --- library/strategy_lumina.py | 43 +++++++++++++++++++------------------- train_network.py | 1 - 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index c9e65423..74f15cec 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -196,6 +196,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) input_ids = data["input_ids"] return [hidden_state, input_ids, attention_mask] + @torch.no_grad() def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, @@ -222,23 +223,21 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) tokens, attention_masks, weights_list = ( tokenize_strategy.tokenize_with_weights(captions) ) - with torch.no_grad(): - hidden_state, input_ids, attention_masks = ( - text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, - models, - (tokens, attention_masks), - weights_list, - ) + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + models, + (tokens, attention_masks), + weights_list, ) + ) else: tokens = tokenize_strategy.tokenize(captions) - with torch.no_grad(): - hidden_state, input_ids, attention_masks = ( - text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens - ) + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens ) + ) if hidden_state.dtype != torch.float32: hidden_state = hidden_state.float() @@ -247,14 +246,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) attention_mask = attention_masks.cpu().numpy() # (B, S) input_ids = input_ids.cpu().numpy() # (B, S) + for i, info in enumerate(batch): hidden_state_i = hidden_state[i] attention_mask_i = attention_mask[i] input_ids_i = input_ids[i] - assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}" - if self.cache_to_disk: + assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}" np.savez( info.text_encoder_outputs_npz, hidden_state=hidden_state_i, @@ -338,21 +337,21 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): # TODO remove circular dependency for ImageInfo def cache_batch_latents( self, - vae, - image_infos: List, + model, + batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool, ): - encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") - vae_device = vae.device - vae_dtype = vae.dtype + encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu") + vae_device = model.device + vae_dtype = model.dtype self._default_cache_batch_latents( encode_by_vae, vae_device, vae_dtype, - image_infos, + batch, flip_aug, alpha_mask, random_crop, @@ -360,4 +359,4 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): ) if not train_util.HIGH_VRAM: - train_util.clean_memory_on_device(vae.device) + train_util.clean_memory_on_device(model.device) diff --git a/train_network.py b/train_network.py index ff62f46a..b4b0d42d 100644 --- a/train_network.py +++ b/train_network.py @@ -1282,7 +1282,6 @@ class NetworkTrainer: # For --sample_at_first optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) - progress_bar.unpause() # Reset progress bar to before sampling images optimizer_train_fn() is_tracking = len(accelerator.trackers) > 0 if is_tracking: From 542f980443feadee0cbab2beeae3f9b3891a3058 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Feb 2025 00:00:20 -0500 Subject: [PATCH 26/73] Fix sample norms in batches --- library/lumina_train_util.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3feb..a95da382 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -680,12 +680,14 @@ def denoise( dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True, ) - max_new_norm = cond_norm * float(renorm_cfg) - noise_norm = torch.linalg.vector_norm( + max_new_norms = cond_norm * float(renorm_cfg) + noise_norms = torch.linalg.vector_norm( noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True ) - if noise_norm >= max_new_norm: - noise_pred = noise_pred * (max_new_norm / noise_norm) + # Iterate through batch + for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred): + if noise_norm >= max_new_norm: + noise = noise * (max_new_norm / noise_norm) else: noise_pred = noise_pred_cond From 0886d976f1d3cca531bc068a5b1a0e54555dc20c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Feb 2025 02:31:50 -0500 Subject: [PATCH 27/73] Add block swap --- library/lumina_models.py | 65 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 1a441a69..c00ca88d 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -29,6 +29,8 @@ from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F +from library import custom_offloading_utils + try: from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -1066,8 +1068,16 @@ class NextDiT(nn.Module): x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t) - for layer in self.layers: - x = layer(x, mask, freqs_cis, t) + if not self.blocks_to_swap: + for layer in self.layers: + x = layer(x, mask, freqs_cis, t) + else: + for block_idx, layer in enumerate(self.layers): + self.offloader_main.wait_for_block(block_idx) + + x = layer(x, mask, freqs_cis, t) + + self.offloader_main.submit_move_blocks(self.layers, block_idx) x = self.final_layer(x, t) x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths) @@ -1184,6 +1194,57 @@ class NextDiT(nn.Module): def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: return list(self.layers) + def enable_block_swap(self, num_blocks: int, device: torch.device): + """ + Enable block swapping to reduce memory usage during inference. + + Args: + num_blocks (int): Number of blocks to swap between CPU and device + device (torch.device): Device to use for computation + """ + self.blocks_to_swap = num_blocks + + # Calculate how many blocks to swap from main layers + num_main_blocks_to_swap = min(num_blocks, self.layers) + + assert num_main_blocks_to_swap <= len(self.layers) - 2, ( + f"Cannot swap more than {len(self.layers) - 2} main blocks. " + f"Requested {num_main_blocks_to_swap} blocks." + ) + + self.offloader_main = custom_offloading_utils.ModelOffloader( + self.layers, len(self.layers), num_main_blocks_to_swap, device + ) + + print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + """ + Move the model to the device except for blocks that will be swapped. + This reduces temporary memory usage during model loading. + + Args: + device (torch.device): Device to move the model to + """ + if self.blocks_to_swap: + save_layers = self.layers + self.layers = None + + self.to(device) + + self.layers = save_layers + else: + self.to(device) + + def prepare_block_swap_before_forward(self): + """ + Prepare blocks for swapping before forward pass. + """ + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + self.offloader_main.prepare_block_devices_before_forward(self.layers) + ############################################################################# # NextDiT Configs # From ce2610d29b399c8353686f50bf1973457a133153 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Feb 2025 02:47:04 -0500 Subject: [PATCH 28/73] Change system prompt to inject Prompt Start special token --- library/lumina_train_util.py | 5 +++-- library/strategy_lumina.py | 3 ++- library/train_util.py | 6 ++++-- lumina_train_network.py | 9 ++++++--- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3feb..bfc470a9 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -330,11 +330,12 @@ def sample_image_inference( logger.info(f"renorm: {renorm_cfg}") # logger.info(f"sample_sampler: {sampler_name}") - system_prompt = args.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" # Apply system prompt to prompts prompt = system_prompt + prompt - negative_prompt = system_prompt + negative_prompt + negative_prompt = negative_prompt # Get sample prompts from cache if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index c9e65423..275e290f 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -216,7 +216,8 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) - captions = [info.system_prompt or "" + info.caption for info in batch] + system_prompt_special_token = "" + captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch] if self.is_weighted: tokens, attention_masks, weights_list = ( diff --git a/library/train_util.py b/library/train_util.py index 0c057bd1..34b98f89 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1692,7 +1692,8 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: - system_prompt = subset.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else "" caption = self.process_caption(subset, image_info.caption) input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension # if self.XTI_layers: @@ -2091,7 +2092,8 @@ class DreamBoothDataset(BaseDataset): else: num_train_images += num_repeats * len(img_paths) - system_prompt = self.system_prompt or subset.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else "" for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path) if size is not None: diff --git a/lumina_train_network.py b/lumina_train_network.py index 5f20c014..c9ef5f02 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -155,7 +155,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - system_prompt = args.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): @@ -164,8 +165,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ] - for prompt in prompts: - prompt = system_prompt + prompt + for i, prompt in enumerate(prompts): + # Add system prompt only to positive prompt + if i == 0: + prompt = system_prompt + prompt if prompt in sample_prompts_te_outputs: continue From 42fe22f5a25e950545b81e53c13b0c1c804d6e46 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Feb 2025 03:21:24 -0500 Subject: [PATCH 29/73] Enable block swap for Lumina --- library/lumina_models.py | 7 +++---- lumina_train_network.py | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index c00ca88d..020320b0 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -1205,15 +1205,14 @@ class NextDiT(nn.Module): self.blocks_to_swap = num_blocks # Calculate how many blocks to swap from main layers - num_main_blocks_to_swap = min(num_blocks, self.layers) - assert num_main_blocks_to_swap <= len(self.layers) - 2, ( + assert num_blocks <= len(self.layers) - 2, ( f"Cannot swap more than {len(self.layers) - 2} main blocks. " - f"Requested {num_main_blocks_to_swap} blocks." + f"Requested {num_blocks} blocks." ) self.offloader_main = custom_offloading_utils.ModelOffloader( - self.layers, len(self.layers), num_main_blocks_to_swap, device + self.layers, len(self.layers), num_blocks, device ) print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.") diff --git a/lumina_train_network.py b/lumina_train_network.py index 5f20c014..44c3f32f 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -73,10 +73,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) model.to(torch.float8_e4m3fn) - # if args.blocks_to_swap: - # logger.info(f'Enabling block swap: {args.blocks_to_swap}') - # model.enable_block_swap(args.blocks_to_swap, accelerator.device) - # self.is_swapping_blocks = True + if args.blocks_to_swap: + logger.info(f'Enabling block swap: {args.blocks_to_swap}') + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + self.is_swapping_blocks = True gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") gemma2.eval() From 9647f1e32485444facb8a5be5eb77dbac797dc71 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Feb 2025 20:36:36 -0500 Subject: [PATCH 30/73] Fix validation block swap. Add custom offloading tests --- library/custom_offloading_utils.py | 30 +- library/flux_models.py | 8 +- library/lumina_models.py | 19 +- library/sd3_models.py | 4 +- library/strategy_lumina.py | 2 +- lumina_train_network.py | 7 +- tests/test_custom_offloading_utils.py | 408 ++++++++++++++++++++++++++ 7 files changed, 446 insertions(+), 32 deletions(-) create mode 100644 tests/test_custom_offloading_utils.py diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 84c2b743..55ff08b6 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -1,6 +1,6 @@ from concurrent.futures import ThreadPoolExecutor import time -from typing import Optional +from typing import Optional, Union, Callable, Tuple import torch import torch.nn as nn @@ -19,7 +19,7 @@ def synchronize_device(device: torch.device): def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - weight_swap_jobs = [] + weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = [] # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): @@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - stream = torch.cuda.Stream() + stream = torch.Stream(device="cuda") with torch.cuda.stream(stream): # cuda to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: @@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l """ assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - weight_swap_jobs = [] + weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = [] for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + # device to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) - synchronize_device() + synchronize_device(device) # cpu to device for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) module_to_cuda.weight.data = cuda_data_view - synchronize_device() + synchronize_device(device) def weighs_to_device(layer: nn.Module, device: torch.device): @@ -148,13 +149,16 @@ class Offloader: print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") +# Gradient tensors +_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor] + class ModelOffloader(Offloader): """ supports forward offloading """ - def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): - super().__init__(num_blocks, blocks_to_swap, device, debug) + def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(len(blocks), blocks_to_swap, device, debug) # register backward hooks self.remove_handles = [] @@ -168,7 +172,7 @@ class ModelOffloader(Offloader): for handle in self.remove_handles: handle.remove() - def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: # -1 for 0-based index num_blocks_propagated = self.num_blocks - block_index - 1 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap @@ -182,7 +186,7 @@ class ModelOffloader(Offloader): block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated block_idx_to_wait = block_index - 1 - def backward_hook(module, grad_input, grad_output): + def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t): if self.debug: print(f"Backward hook for block {block_index}") @@ -194,7 +198,7 @@ class ModelOffloader(Offloader): return backward_hook - def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return @@ -207,7 +211,7 @@ class ModelOffloader(Offloader): for b in blocks[self.num_blocks - self.blocks_to_swap :]: b.to(self.device) # move block to device first - weighs_to_device(b, "cpu") # make sure weights are on cpu + weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu synchronize_device(self.device) clean_memory_on_device(self.device) @@ -217,7 +221,7 @@ class ModelOffloader(Offloader): return self._wait_blocks_move(block_idx) - def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return if block_idx >= self.blocks_to_swap: diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..b00bdae2 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1219,10 +1219,10 @@ class ControlNetFlux(nn.Module): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, double_blocks_to_swap, device # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, single_blocks_to_swap, device # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1233,8 +1233,8 @@ class ControlNetFlux(nn.Module): if self.blocks_to_swap: save_double_blocks = self.double_blocks save_single_blocks = self.single_blocks - self.double_blocks = None - self.single_blocks = None + self.double_blocks = nn.ModuleList() + self.single_blocks = nn.ModuleList() self.to(device) diff --git a/library/lumina_models.py b/library/lumina_models.py index 020320b0..2d4c6527 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -1194,7 +1194,7 @@ class NextDiT(nn.Module): def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: return list(self.layers) - def enable_block_swap(self, num_blocks: int, device: torch.device): + def enable_block_swap(self, blocks_to_swap: int, device: torch.device): """ Enable block swapping to reduce memory usage during inference. @@ -1202,20 +1202,18 @@ class NextDiT(nn.Module): num_blocks (int): Number of blocks to swap between CPU and device device (torch.device): Device to use for computation """ - self.blocks_to_swap = num_blocks + self.blocks_to_swap = blocks_to_swap # Calculate how many blocks to swap from main layers - assert num_blocks <= len(self.layers) - 2, ( + assert blocks_to_swap <= len(self.layers) - 2, ( f"Cannot swap more than {len(self.layers) - 2} main blocks. " - f"Requested {num_blocks} blocks." + f"Requested {blocks_to_swap} blocks." ) self.offloader_main = custom_offloading_utils.ModelOffloader( - self.layers, len(self.layers), num_blocks, device + self.layers, blocks_to_swap, device, debug=False ) - - print(f"NextDiT: Block swap enabled. Swapping {num_blocks} blocks.") def move_to_device_except_swap_blocks(self, device: torch.device): """ @@ -1227,13 +1225,12 @@ class NextDiT(nn.Module): """ if self.blocks_to_swap: save_layers = self.layers - self.layers = None + self.layers = nn.ModuleList([]) - self.to(device) + self.to(device) + if self.blocks_to_swap: self.layers = save_layers - else: - self.to(device) def prepare_block_swap_before_forward(self): """ diff --git a/library/sd3_models.py b/library/sd3_models.py index e4a93186..996f8192 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1080,7 +1080,7 @@ class MMDiT(nn.Module): ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." self.offloader = custom_offloading_utils.ModelOffloader( - self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + self.joint_blocks, self.blocks_to_swap, device # , debug=True ) print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") @@ -1088,7 +1088,7 @@ class MMDiT(nn.Module): # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_blocks = self.joint_blocks - self.joint_blocks = None + self.joint_blocks = nn.ModuleList() self.to(device) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index c9e65423..714326ad 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -208,7 +208,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy models (List[Any]): Text encoders text_encoding_strategy (LuminaTextEncodingStrategy): - infos (List): List of image_info + infos (List): List of ImageInfo Returns: None diff --git a/lumina_train_network.py b/lumina_train_network.py index 44c3f32f..3e003a92 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -74,7 +74,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): model.to(torch.float8_e4m3fn) if args.blocks_to_swap: - logger.info(f'Enabling block swap: {args.blocks_to_swap}') + logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}') model.enable_block_swap(args.blocks_to_swap, accelerator.device) self.is_swapping_blocks = True @@ -361,6 +361,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return nextdit + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/tests/test_custom_offloading_utils.py b/tests/test_custom_offloading_utils.py new file mode 100644 index 00000000..5fa40b76 --- /dev/null +++ b/tests/test_custom_offloading_utils.py @@ -0,0 +1,408 @@ +import pytest +import torch +import torch.nn as nn +from unittest.mock import patch, MagicMock + +from library.custom_offloading_utils import ( + synchronize_device, + swap_weight_devices_cuda, + swap_weight_devices_no_cuda, + weighs_to_device, + Offloader, + ModelOffloader +) + +class TransformerBlock(nn.Module): + def __init__(self, block_idx: int): + super().__init__() + self.block_idx = block_idx + self.linear1 = nn.Linear(10, 5) + self.linear2 = nn.Linear(5, 10) + self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10)) + + def forward(self, x): + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + x = self.seq(x) + return x + + +class SimpleModel(nn.Module): + def __init__(self, num_blocks=16): + super().__init__() + self.blocks = nn.ModuleList([ + TransformerBlock(i) + for i in range(num_blocks)]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + +# Device Synchronization Tests +@patch('torch.cuda.synchronize') +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_synchronize(mock_cuda_sync): + device = torch.device('cuda') + synchronize_device(device) + mock_cuda_sync.assert_called_once() + +@patch('torch.xpu.synchronize') +@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available") +def test_xpu_synchronize(mock_xpu_sync): + device = torch.device('xpu') + synchronize_device(device) + mock_xpu_sync.assert_called_once() + +@patch('torch.mps.synchronize') +@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available") +def test_mps_synchronize(mock_mps_sync): + device = torch.device('mps') + synchronize_device(device) + mock_mps_sync.assert_called_once() + + +# Weights to Device Tests +def test_weights_to_device(): + # Create a simple model with weights + model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 2) + ) + + # Start with CPU tensors + device = torch.device('cpu') + for module in model.modules(): + if hasattr(module, "weight") and module.weight is not None: + assert module.weight.device == device + + # Move to mock CUDA device + mock_device = torch.device('cuda') + with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)): + weighs_to_device(model, mock_device) + + # Since we mocked the to() function, we can only verify modules were processed + # but can't check actual device movement + + +# Swap Weight Devices Tests +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_swap_weight_devices_cuda(): + device = torch.device('cuda') + layer_to_cpu = SimpleModel() + layer_to_cuda = SimpleModel() + + # Move layer to CUDA to move to CPU + layer_to_cpu.to(device) + + with patch('torch.Tensor.to', return_value=torch.zeros(1)): + with patch('torch.Tensor.copy_'): + swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda) + + assert layer_to_cpu.device.type == 'cpu' + assert layer_to_cuda.device.type == 'cuda' + + + +@patch('library.custom_offloading_utils.synchronize_device') +def test_swap_weight_devices_no_cuda(mock_sync_device): + device = torch.device('cpu') + layer_to_cpu = SimpleModel() + layer_to_cuda = SimpleModel() + + with patch('torch.Tensor.to', return_value=torch.zeros(1)): + with patch('torch.Tensor.copy_'): + swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda) + + # Verify synchronize_device was called twice + assert mock_sync_device.call_count == 2 + + +# Offloader Tests +@pytest.fixture +def offloader(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return Offloader( + num_blocks=4, + blocks_to_swap=2, + device=device, + debug=False + ) + + +def test_offloader_init(offloader): + assert offloader.num_blocks == 4 + assert offloader.blocks_to_swap == 2 + assert hasattr(offloader, 'thread_pool') + assert offloader.futures == {} + assert offloader.cuda_available == (offloader.device.type == 'cuda') + + +@patch('library.custom_offloading_utils.swap_weight_devices_cuda') +@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda') +def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader): + block_to_cpu = SimpleModel() + block_to_cuda = SimpleModel() + + # Force test for CUDA device + offloader.cuda_available = True + offloader.swap_weight_devices(block_to_cpu, block_to_cuda) + mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda) + mock_no_cuda.assert_not_called() + + # Reset mocks + mock_cuda.reset_mock() + mock_no_cuda.reset_mock() + + # Force test for non-CUDA device + offloader.cuda_available = False + offloader.swap_weight_devices(block_to_cpu, block_to_cuda) + mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda) + mock_cuda.assert_not_called() + + +@patch('library.custom_offloading_utils.Offloader.swap_weight_devices') +def test_submit_move_blocks(mock_swap, offloader): + blocks = [SimpleModel() for _ in range(4)] + block_idx_to_cpu = 0 + block_idx_to_cuda = 2 + + # Mock the thread pool to execute synchronously + future = MagicMock() + future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda) + offloader.thread_pool.submit = MagicMock(return_value=future) + + offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + # Check that the future is stored with the correct key + assert block_idx_to_cuda in offloader.futures + + +def test_wait_blocks_move(offloader): + block_idx = 2 + + # Test with no future for the block + offloader._wait_blocks_move(block_idx) # Should not raise + + # Create a fake future and test waiting + future = MagicMock() + future.result.return_value = (0, block_idx) + offloader.futures[block_idx] = future + + offloader._wait_blocks_move(block_idx) + + # Check that the future was removed + assert block_idx not in offloader.futures + future.result.assert_called_once() + + +# ModelOffloader Tests +@pytest.fixture +def model_offloader(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + blocks = SimpleModel(4).blocks + return ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + +def test_model_offloader_init(model_offloader): + assert model_offloader.num_blocks == 4 + assert model_offloader.blocks_to_swap == 2 + assert hasattr(model_offloader, 'thread_pool') + assert model_offloader.futures == {} + assert len(model_offloader.remove_handles) > 0 # Should have registered hooks + + +def test_create_backward_hook(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + blocks = SimpleModel(4).blocks + model_offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + # Test hook creation for swapping case (block 0) + hook_swap = model_offloader.create_backward_hook(blocks, 0) + assert hook_swap is None + + # Test hook creation for waiting case (block 1) + hook_wait = model_offloader.create_backward_hook(blocks, 1) + assert hook_wait is not None + + # Test hook creation for no action case (block 3) + hook_none = model_offloader.create_backward_hook(blocks, 3) + assert hook_none is None + + +@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks') +@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move') +def test_backward_hook_execution(mock_wait, mock_submit): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + model = SimpleModel(4) + blocks = model.blocks + model_offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + # Test swapping hook (block 1) + hook_swap = model_offloader.create_backward_hook(blocks, 1) + assert hook_swap is not None + hook_swap(model, torch.zeros(1), torch.zeros(1)) + mock_submit.assert_called_once() + + mock_submit.reset_mock() + + # Test waiting hook (block 2) + hook_wait = model_offloader.create_backward_hook(blocks, 2) + assert hook_wait is not None + hook_wait(model, torch.zeros(1), torch.zeros(1)) + assert mock_wait.call_count == 2 + + +@patch('library.custom_offloading_utils.weighs_to_device') +@patch('library.custom_offloading_utils.synchronize_device') +@patch('library.custom_offloading_utils.clean_memory_on_device') +def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader): + model = SimpleModel(4) + blocks = model.blocks + + with patch.object(nn.Module, 'to'): + model_offloader.prepare_block_devices_before_forward(blocks) + + # Check that weighs_to_device was called for each block + assert mock_weights_to_device.call_count == 4 + + # Check that synchronize_device and clean_memory_on_device were called + mock_sync.assert_called_once_with(model_offloader.device) + mock_clean.assert_called_once_with(model_offloader.device) + + +@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move') +def test_wait_for_block(mock_wait, model_offloader): + # Test with blocks_to_swap=0 + model_offloader.blocks_to_swap = 0 + model_offloader.wait_for_block(1) + mock_wait.assert_not_called() + + # Test with blocks_to_swap=2 + model_offloader.blocks_to_swap = 2 + block_idx = 1 + model_offloader.wait_for_block(block_idx) + mock_wait.assert_called_once_with(block_idx) + + +@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks') +def test_submit_move_blocks(mock_submit, model_offloader): + model = SimpleModel() + blocks = model.blocks + + # Test with blocks_to_swap=0 + model_offloader.blocks_to_swap = 0 + model_offloader.submit_move_blocks(blocks, 1) + mock_submit.assert_not_called() + + mock_submit.reset_mock() + model_offloader.blocks_to_swap = 2 + + # Test within swap range + block_idx = 1 + model_offloader.submit_move_blocks(blocks, block_idx) + mock_submit.assert_called_once() + + mock_submit.reset_mock() + + # Test outside swap range + block_idx = 3 + model_offloader.submit_move_blocks(blocks, block_idx) + mock_submit.assert_not_called() + + +# Integration test for offloading in a realistic scenario +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_offloading_integration(): + device = torch.device('cuda') + # Create a mini model with 4 blocks + model = SimpleModel(5) + model.to(device) + blocks = model.blocks + + # Initialize model offloader + offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=2, + device=device, + debug=True + ) + + # Prepare blocks for forward pass + offloader.prepare_block_devices_before_forward(blocks) + + # Simulate forward pass with offloading + input_tensor = torch.randn(1, 10, device=device) + x = input_tensor + + for i, block in enumerate(blocks): + # Wait for the current block to be ready + offloader.wait_for_block(i) + + # Process through the block + x = block(x) + + # Schedule moving weights for future blocks + offloader.submit_move_blocks(blocks, i) + + # Verify we get a valid output + assert x.shape == (1, 10) + assert not torch.isnan(x).any() + + +# Error handling tests +def test_offloader_assertion_error(): + with pytest.raises(AssertionError): + device = torch.device('cpu') + layer_to_cpu = SimpleModel() + layer_to_cuda = nn.Linear(10, 5) # Different class + swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda) + +if __name__ == "__main__": + # Run all tests when file is executed directly + import sys + + # Configure pytest command line arguments + pytest_args = [ + "-v", # Verbose output + "--color=yes", # Colored output + __file__, # Run tests in this file + ] + + # Add optional arguments from command line + if len(sys.argv) > 1: + pytest_args.extend(sys.argv[1:]) + + # Print info about test execution + print(f"Running tests with PyTorch {torch.__version__}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + + # Run the tests + sys.exit(pytest.main(pytest_args)) From d6f7e2e20cfe91eb0c7a5f4c277107f7b699d97f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 28 Feb 2025 14:08:27 -0500 Subject: [PATCH 31/73] Fix block swap for sample images --- library/flux_train_utils.py | 1 - library/lumina_train_util.py | 3 ++- lumina_train_network.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..c6d2baeb 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -317,7 +317,6 @@ def denoise( # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3feb..e008b3ce 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -604,7 +604,6 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps - def denoise( scheduler, model: lumina_models.NextDiT, @@ -648,6 +647,7 @@ def denoise( """ for i, t in enumerate(tqdm(timesteps)): + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -700,6 +700,7 @@ def denoise( noise_pred = -noise_pred img = scheduler.step(noise_pred, t, img, return_dict=False)[0] + model.prepare_block_swap_before_forward() return img diff --git a/lumina_train_network.py b/lumina_train_network.py index 3e003a92..60c39c20 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -367,6 +367,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser) From 1bba7acd9ac42ef5a654cadf47356d20d407ce82 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 28 Feb 2025 14:11:53 -0500 Subject: [PATCH 32/73] Add block swap in sample image timestep loop --- library/lumina_train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index e008b3ce..0be81df9 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -647,6 +647,7 @@ def denoise( """ for i, t in enumerate(tqdm(timesteps)): + model.prepare_block_swap_before_forward() # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps From a2daa870074310ba2415da993016f0779c8b56e2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 28 Feb 2025 14:22:39 -0500 Subject: [PATCH 33/73] Add block swap for uncond (neg) for sample images --- library/lumina_train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 0be81df9..933a4eda 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -665,6 +665,7 @@ def denoise( # compute whether to apply classifier-free guidance based on current timestep if current_timestep[0] < cfg_trunc_ratio: + model.prepare_block_swap_before_forward() noise_pred_uncond = model( img, current_timestep, From cad182d29a2f3ad3ed7550b258025f3243981464 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 28 Feb 2025 18:30:16 -0500 Subject: [PATCH 34/73] fix torch compile/dynamo for Gemma2 --- library/strategy_lumina.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index c9e65423..b4c94106 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -97,7 +97,8 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): hidden_states, input_ids, attention_masks """ text_encoder = models[0] - assert isinstance(text_encoder, Gemma2Model) + # Check model or torch dynamo OptimizedModule + assert isinstance(text_encoder, Gemma2Model) or isinstance(text_encoder._orig_mod, Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}" input_ids, attention_masks = tokens outputs = text_encoder( From a69884a2090076a4bf7f4dedf4cc6aa82789e3bc Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sat, 1 Mar 2025 20:37:45 -0500 Subject: [PATCH 35/73] Add Sage Attention for Lumina --- library/lumina_models.py | 91 +++++++++++++++++++++++++++++++++--- library/lumina_train_util.py | 5 ++ library/lumina_util.py | 3 +- lumina_train_network.py | 1 + 4 files changed, 93 insertions(+), 7 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 1a441a69..00ac16d5 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -36,6 +36,11 @@ except: # flash_attn may not be available but it is not required pass +try: + from sageattention import sageattn +except: + pass + try: from apex.normalization import FusedRMSNorm as RMSNorm except: @@ -271,6 +276,7 @@ class JointAttention(nn.Module): n_kv_heads: Optional[int], qk_norm: bool, use_flash_attn=False, + use_sage_attn=False, ): """ Initialize the Attention module. @@ -310,13 +316,20 @@ class JointAttention(nn.Module): self.q_norm = self.k_norm = nn.Identity() self.use_flash_attn = use_flash_attn + self.use_sage_attn = use_sage_attn - # self.attention_processor = xformers.ops.memory_efficient_attention - self.attention_processor = F.scaled_dot_product_attention + if use_sage_attn : + self.attention_processor = self.sage_attn + else: + # self.attention_processor = xformers.ops.memory_efficient_attention + self.attention_processor = F.scaled_dot_product_attention def set_attention_processor(self, attention_processor): self.attention_processor = attention_processor + def get_attention_processor(self): + return self.attention_processor + def forward( self, x: Tensor, @@ -352,7 +365,15 @@ class JointAttention(nn.Module): softmax_scale = math.sqrt(1 / self.head_dim) - if self.use_flash_attn: + if self.use_sage_attn: + # Handle GQA (Grouped Query Attention) if needed + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale) + elif self.use_flash_attn: output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) else: n_rep = self.n_local_heads // self.n_local_kv_heads @@ -428,6 +449,63 @@ class JointAttention(nn.Module): (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float): + try: + bsz = q.shape[0] + seqlen = q.shape[1] + + # Transpose tensors to match SageAttention's expected format (HND layout) + q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + + # Handle masking for SageAttention + # We need to filter out masked positions - this approach handles variable sequence lengths + outputs = [] + for b in range(bsz): + # Find valid token positions from the mask + valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1) + if valid_indices.numel() == 0: + # If all tokens are masked, create a zero output + batch_output = torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype + ) + else: + # Extract only valid tokens for this batch + batch_q = q_transposed[b, :, valid_indices, :] + batch_k = k_transposed[b, :, valid_indices, :] + batch_v = v_transposed[b, :, valid_indices, :] + + # Run SageAttention on valid tokens only + batch_output_valid = sageattn( + batch_q.unsqueeze(0), # Add batch dimension back + batch_k.unsqueeze(0), + batch_v.unsqueeze(0), + tensor_layout="HND", + is_causal=False, + sm_scale=softmax_scale + ) + + # Create output tensor with zeros for masked positions + batch_output = torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype + ) + # Place valid outputs back in the right positions + batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2) + + outputs.append(batch_output) + + # Stack batch outputs and reshape to expected format + output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim] + except NameError as e: + raise RuntimeError( + f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}" + ) + + return output + def flash_attn( self, q: Tensor, @@ -571,6 +649,7 @@ class JointTransformerBlock(GradientCheckpointMixin): qk_norm: bool, modulation=True, use_flash_attn=False, + use_sage_attn=False, ) -> None: """ Initialize a TransformerBlock. @@ -593,7 +672,7 @@ class JointTransformerBlock(GradientCheckpointMixin): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn) self.feed_forward = FeedForward( dim=dim, hidden_dim=4 * dim, @@ -764,6 +843,7 @@ class NextDiT(nn.Module): axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512], use_flash_attn=False, + use_sage_attn=False, ) -> None: """ Initialize the NextDiT model. @@ -817,7 +897,6 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=False, - use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -843,7 +922,6 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, - use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -865,6 +943,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, use_flash_attn=use_flash_attn, + use_sage_attn=use_sage_attn, ) for layer_id in range(n_layers) ] diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3feb..d3a54a74 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1077,6 +1077,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): action="store_true", help="Use Flash Attention for the model / モデルにFlash Attentionを使用する", ) + parser.add_argument( + "--use_sage_attn", + action="store_true", + help="Use Sage Attention for the model / モデルにSage Attentionを使用する", + ) parser.add_argument( "--system_prompt", type=str, diff --git a/library/lumina_util.py b/library/lumina_util.py index d9c89938..06f089d4 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -27,6 +27,7 @@ def load_lumina_model( device: torch.device, disable_mmap: bool = False, use_flash_attn: bool = False, + use_sage_attn: bool = False, ): """ Load the Lumina model from the checkpoint path. @@ -43,7 +44,7 @@ def load_lumina_model( """ logger.info("Building Lumina") with torch.device("meta"): - model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype) + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) diff --git a/lumina_train_network.py b/lumina_train_network.py index 5f20c014..ed1f3aae 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -58,6 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, use_flash_attn=args.use_flash_attn, + use_sage_attn=args.use_sage_attn ) if args.fp8_base: From 5e45df722d434bd64b230f462cac632d5ea68c96 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Tue, 4 Mar 2025 08:07:33 +0800 Subject: [PATCH 36/73] update gemma2 train attention layer --- networks/lora_lumina.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 3f6c9b41..431c183d 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei class LoRANetwork(torch.nn.Module): LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"] LORA_PREFIX_LUMINA = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder From 1f22a94cfe55491cc708adfa881953db423a886f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 4 Mar 2025 02:21:05 -0500 Subject: [PATCH 37/73] Update embedder_dims, add more flexible caption extension --- library/lumina_models.py | 6 +- library/train_util.py | 39 ++++--- networks/lora_lumina.py | 235 ++++++++++++++++++++++----------------- 3 files changed, 159 insertions(+), 121 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index e00dcf96..2508cc7d 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -887,6 +887,9 @@ class NextDiT(nn.Module): ), ) + nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) + nn.init.zeros_(self.cap_embedder[1].bias) + self.context_refiner = nn.ModuleList( [ JointTransformerBlock( @@ -929,9 +932,6 @@ class NextDiT(nn.Module): ] ) - nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) - # nn.init.zeros_(self.cap_embedder[1].weight) - nn.init.zeros_(self.cap_embedder[1].bias) self.layers = nn.ModuleList( [ diff --git a/library/train_util.py b/library/train_util.py index 34b98f89..c07a4a73 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -529,8 +529,8 @@ class DreamBoothSubset(BaseSubset): self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension - if self.caption_extension and not self.caption_extension.startswith("."): - self.caption_extension = "." + self.caption_extension + # if self.caption_extension and not self.caption_extension.startswith("."): + # self.caption_extension = "." + self.caption_extension self.cache_info = cache_info def __eq__(self, other) -> bool: @@ -1895,30 +1895,33 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension, enable_wildcard): + def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name tokens = base_name.split("_") if len(tokens) >= 5: base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)] caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding="utf-8") as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - if enable_wildcard: - caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 - else: - caption = lines[0].strip() - break + for base, cap_extension in cap_paths: + # check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt) + for cap_path in [base + cap_extension, base + "." + cap_extension]: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() + break + break return caption def load_dreambooth_dir(subset: DreamBoothSubset): diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 431c183d..f856d4e7 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -1,5 +1,5 @@ # temporary minimum implementation of LoRA -# FLUX doesn't have Conv2d, so we ignore it +# Lumina 2 does not have Conv2d, so ignore # TODO commonize with the original implementation # LoRA network module @@ -10,13 +10,11 @@ import math import os from typing import Dict, List, Optional, Tuple, Type, Union -from diffusers import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from transformers import CLIPTextModel -import numpy as np import torch -import re +from torch import Tensor, nn from library.utils import setup_logging -from library.sdxl_original_unet import SdxlUNet2DConditionModel setup_logging() import logging @@ -35,14 +33,14 @@ class LoRAModule(torch.nn.Module): def __init__( self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, + lora_name: str, + org_module: nn.Module, + multiplier: float =1.0, + lora_dim: int = 4, + alpha: Optional[float | int | Tensor] = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, split_dims: Optional[List[int]] = None, ): """ @@ -60,6 +58,9 @@ class LoRAModule(torch.nn.Module): in_dim = org_module.in_features out_dim = org_module.out_features + assert isinstance(in_dim, int) + assert isinstance(out_dim, int) + self.lora_dim = lora_dim self.split_dims = split_dims @@ -68,30 +69,31 @@ class LoRAModule(torch.nn.Module): kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False) - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up.weight) else: # conv2d not supported assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" # print(f"split_dims: {split_dims}") - self.lora_down = torch.nn.ModuleList( - [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + self.lora_down = nn.ModuleList( + [nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] ) - self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) - for lora_down in self.lora_down: - torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) - for lora_up in self.lora_up: - torch.nn.init.zeros_(lora_up.weight) + self.lora_up = nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + for lora_down in self.lora_down: + nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + nn.init.zeros_(lora_up.weight) + + if isinstance(alpha, Tensor): + alpha = alpha.detach().cpu().float().item() # without casting, bf16 causes error alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える @@ -140,6 +142,9 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + return org_forwarded + lx * self.multiplier * scale else: lxs = [lora_down(x) for lora_down in self.lora_down] @@ -152,9 +157,9 @@ class LoRAModule(torch.nn.Module): if self.rank_dropout is not None and self.training: masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] for i in range(len(lxs)): - if len(lx.size()) == 3: + if len(lxs[i].size()) == 3: masks[i] = masks[i].unsqueeze(1) - elif len(lx.size()) == 4: + elif len(lxs[i].size()) == 4: masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) lxs[i] = lxs[i] * masks[i] @@ -165,6 +170,9 @@ class LoRAModule(torch.nn.Module): lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale @@ -339,14 +347,14 @@ def create_network( if all([d is None for d in type_dims]): type_dims = None - # in_dims for embedders - in_dims = kwargs.get("in_dims", None) - if in_dims is not None: - in_dims = in_dims.strip() - if in_dims.startswith("[") and in_dims.endswith("]"): - in_dims = in_dims[1:-1] - in_dims = [int(d) for d in in_dims.split(",")] - assert len(in_dims) == 4, f"invalid in_dims: {in_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder, final_layer)" + # embedder_dims for embedders + embedder_dims = kwargs.get("embedder_dims", None) + if embedder_dims is not None: + embedder_dims = embedder_dims.strip() + if embedder_dims.startswith("[") and embedder_dims.endswith("]"): + embedder_dims = embedder_dims[1:-1] + embedder_dims = [int(d) for d in embedder_dims.split(",")] + assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder)" # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) @@ -357,9 +365,9 @@ def create_network( module_dropout = float(module_dropout) # single or double blocks - train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "transformer", "refiners", "noise_refiner", "context_refiner" if train_blocks is not None: - assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}" # split qkv split_qkv = kwargs.get("split_qkv", False) @@ -386,7 +394,7 @@ def create_network( train_blocks=train_blocks, split_qkv=split_qkv, type_dims=type_dims, - in_dims=in_dims, + embedder_dims=embedder_dims, verbose=verbose, ) @@ -461,7 +469,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei class LoRANetwork(torch.nn.Module): - LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] + LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"] LORA_PREFIX_LUMINA = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder @@ -478,13 +486,14 @@ class LoRANetwork(torch.nn.Module): module_dropout: Optional[float] = None, conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, - module_class: Type[object] = LoRAModule, + module_class: Type[LoRAModule] = LoRAModule, modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, split_qkv: bool = False, type_dims: Optional[List[int]] = None, - in_dims: Optional[List[int]] = None, + embedder_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -501,7 +510,9 @@ class LoRANetwork(torch.nn.Module): self.split_qkv = split_qkv self.type_dims = type_dims - self.in_dims = in_dims + self.embedder_dims = embedder_dims + + self.train_block_indices = train_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -509,7 +520,7 @@ class LoRANetwork(torch.nn.Module): if modules_dim is not None: logger.info(f"create LoRA network from weights") - self.in_dims = [0] * 5 # create in_dims + self.embedder_dims = [0] * 5 # create embedder_dims # verbose = True else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") @@ -529,7 +540,7 @@ class LoRANetwork(torch.nn.Module): def create_modules( is_lumina: bool, root_module: torch.nn.Module, - target_replace_modules: List[str], + target_replace_modules: Optional[List[str]], filter: Optional[str] = None, default_dim: Optional[int] = None, ) -> List[LoRAModule]: @@ -544,63 +555,77 @@ class LoRANetwork(torch.nn.Module): for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - if is_linear or is_conv2d: - lora_name = prefix + "." + (name + "." if name else "") + child_name - lora_name = lora_name.replace(".", "_") + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") - if filter is not None and not filter in lora_name: - continue + # Only Linear is supported + if not is_linear: + skipped.append(lora_name) + continue - dim = None - alpha = None + if filter is not None and filter not in lora_name: + continue - if modules_dim is not None: - # モジュール指定あり - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - else: - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = default_dim if default_dim is not None else self.lora_dim - alpha = self.alpha + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha - if is_lumina and type_dims is not None: - identifier = [ - ("attention",), # attention layers - ("mlp",), # MLP layers - ("modulation",), # modulation layers - ("refiner",), # refiner blocks - ] - for i, d in enumerate(type_dims): - if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d # may be 0 for skip - break + # Set dim/alpha to modules dim/alpha + if modules_dim is not None and modules_alpha is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha + # Set dims to type_dims + if is_lumina and type_dims is not None: + identifier = [ + ("attention",), # attention layers + ("mlp",), # MLP layers + ("modulation",), # modulation layers + ("refiner",), # refiner blocks + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break - if dim is None or dim == 0: - # skipした情報を出力 - if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): - skipped.append(lora_name) - continue - - lora = module_class( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, + # Drop blocks if we are only training some blocks + if ( + is_lumina + and dim + and ( + self.train_block_indices is not None ) - loras.append(lora) + and ("layer" in lora_name) + ): + # "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or or "lora_unet_noise_refiner_0_..." + block_index = int(lora_name.split("_")[3]) # bit dirty + if ( + "layer" in lora_name + and self.train_block_indices is not None + and not self.train_block_indices[block_index] + ): + dim = 0 + + + if dim is None or dim == 0: + # skipした情報を出力 + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + logger.info(f"Add LoRA module: {lora_name}") + loras.append(lora) if target_replace_modules is None: break # all modules are searched @@ -617,15 +642,25 @@ class LoRANetwork(torch.nn.Module): skipped_te += skipped # create LoRA for U-Net - target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + # TODO: limit different blocks + elif self.train_blocks == "transformer": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "refiners": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "noise_refiner": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "cap_refiner": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) # Handle embedders - if self.in_dims: - for filter, in_dim in zip(["x_embedder", "t_embedder", "cap_embedder", "final_layer"], self.in_dims): - loras, _ = create_modules(True, unet, None, filter=filter, default_dim=in_dim) + if self.embedder_dims: + for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims): + loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim) self.unet_loras.extend(loras) logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.") From 9fe8a470800e70a6d899dd63d09e1d63954d67fb Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 4 Mar 2025 02:28:56 -0500 Subject: [PATCH 38/73] Undo dropout after up --- networks/lora_lumina.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index f856d4e7..03d13039 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -22,10 +22,6 @@ import logging logger = logging.getLogger(__name__) -NUM_DOUBLE_BLOCKS = 19 -NUM_SINGLE_BLOCKS = 38 - - class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -142,9 +138,6 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) - return org_forwarded + lx * self.multiplier * scale else: lxs = [lora_down(x) for lora_down in self.lora_down] @@ -170,9 +163,6 @@ class LoRAModule(torch.nn.Module): lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] - if self.dropout is not None and self.training: - lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] - return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale From e8c15c716789c5b50a10190871145db2a2aad9f9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 4 Mar 2025 02:30:08 -0500 Subject: [PATCH 39/73] Remove log --- networks/lora_lumina.py | 1 - 1 file changed, 1 deletion(-) diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 03d13039..15c35f44 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -614,7 +614,6 @@ class LoRANetwork(torch.nn.Module): rank_dropout=rank_dropout, module_dropout=module_dropout, ) - logger.info(f"Add LoRA module: {lora_name}") loras.append(lora) if target_replace_modules is None: From 2ba1cc7791a5438448b99d70929c6c9a54c70e73 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 21 Mar 2025 20:17:22 -0400 Subject: [PATCH 40/73] Fix max norms not applying to noise --- library/lumina_train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index f224e86c..14a79bb2 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -688,9 +688,9 @@ def denoise( noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True ) # Iterate through batch - for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred): + for i, (noise_norm, max_new_norm) in enumerate(zip(noise_norms, max_new_norms)): if noise_norm >= max_new_norm: - noise = noise * (max_new_norm / noise_norm) + noise_pred[i] = noise_pred[i] * (max_new_norm / noise_norm) else: noise_pred = noise_pred_cond From 61f7283167b2f4002b78ad4487041c10cfc2134a Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 21 Mar 2025 20:38:43 -0400 Subject: [PATCH 41/73] Fix non-cache vae encode --- lumina_train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lumina_train_network.py b/lumina_train_network.py index 6b7e7d22..e1b45ac7 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -230,7 +230,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) # not sure, they use same flux vae From 00e12eed657423c6e0c86a4b2134cb04aceac42c Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sun, 6 Apr 2025 16:09:29 +0800 Subject: [PATCH 42/73] update for lost change --- library/flux_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index b00bdae2..a945a1cb 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -977,10 +977,10 @@ class Flux(nn.Module): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, double_blocks_to_swap, device # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, single_blocks_to_swap, device # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." From 7f93e21f30a0964fd6bdbe5a84d8d6af6d2f4081 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sun, 6 Apr 2025 16:21:48 +0800 Subject: [PATCH 43/73] fix typo --- library/train_util.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 43a8a0fe..ba6e4cb9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -434,7 +434,7 @@ class BaseSubset: custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir @@ -500,7 +500,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -529,7 +529,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt + system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) @@ -573,7 +573,7 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -602,7 +602,7 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt + system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) @@ -642,7 +642,7 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -671,7 +671,7 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt + system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) From 899f3454b6f92b48a4d5780549edd92a6bc9db49 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 23 Apr 2025 15:47:12 +0800 Subject: [PATCH 44/73] update for init problem --- library/train_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba6e4cb9..4babb8db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2174,7 +2174,8 @@ class FineTuningDataset(BaseDataset): debug_dataset: bool, validation_seed: int, validation_split: float, - resize_interpolation: Optional[str], + system_prompt: Optional[str] = None, + resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2402,7 +2403,8 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], + system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) From 4fc917821ac014972538888f5cf59d9dd1df502b Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 23 Apr 2025 16:16:36 +0800 Subject: [PATCH 45/73] fix bugs --- library/train_util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4babb8db..e2d0d175 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1883,8 +1883,8 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - system_prompt: Optional[str], - resize_interpolation: Optional[str], + system_prompt: Optional[str] = None, + resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2458,6 +2458,7 @@ class ControlNetDataset(BaseDataset): debug_dataset, validation_split, validation_seed, + system_prompt, resize_interpolation, ) From d94bed645a4d899cffd0bce5804fcf32c4500ad3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 9 Jun 2025 21:14:51 -0400 Subject: [PATCH 46/73] Add lumina tests and fix image masks --- library/lumina_models.py | 6 + library/lumina_util.py | 83 ++++--- library/sd3_train_utils.py | 259 +++------------------ tests/library/test_lumina_models.py | 295 ++++++++++++++++++++++++ tests/library/test_lumina_train_util.py | 241 +++++++++++++++++++ tests/library/test_lumina_util.py | 112 +++++++++ tests/library/test_strategy_lumina.py | 227 ++++++++++++++++++ tests/test_lumina_train_network.py | 173 ++++++++++++++ 8 files changed, 1129 insertions(+), 267 deletions(-) create mode 100644 tests/library/test_lumina_models.py create mode 100644 tests/library/test_lumina_train_util.py create mode 100644 tests/library/test_lumina_util.py create mode 100644 tests/library/test_strategy_lumina.py create mode 100644 tests/test_lumina_train_network.py diff --git a/library/lumina_models.py b/library/lumina_models.py index 2508cc7d..7e925352 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -868,6 +868,8 @@ class NextDiT(nn.Module): cap_feat_dim (int): Dimension of the caption features. axes_dims (List[int]): List of dimensions for the axes. axes_lens (List[int]): List of lengths for the axes. + use_flash_attn (bool): Whether to use Flash Attention. + use_sage_attn (bool): Whether to use Sage Attention. Sage Attention only supports inference. Returns: None @@ -1110,7 +1112,11 @@ class NextDiT(nn.Module): cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) + for i in range(bsz): + x[i, :image_seq_len] = x[i] + x_mask[i, :image_seq_len] = True x = self.x_embedder(x) diff --git a/library/lumina_util.py b/library/lumina_util.py index 06f089d4..452b242f 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -173,62 +173,61 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: return x -DIFFUSERS_TO_ALPHA_VLLM_MAP = { + +DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = { # Embedding layers - "cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"], - "cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight", - "cap_embedder.1.bias": "text_embedder.1.bias", - "x_embedder.weight": "patch_embedder.proj.weight", - "x_embedder.bias": "patch_embedder.proj.bias", + "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight", + "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight", + "text_embedder.1.bias": "cap_embedder.1.bias", + "patch_embedder.proj.weight": "x_embedder.weight", + "patch_embedder.proj.bias": "x_embedder.bias", # Attention modulation - "layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight", - "layers.().adaLN_modulation.1.bias": "transformer_blocks.().adaln_modulation.1.bias", + "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight", + "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias", # Final layers - "final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight", - "final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias", - "final_layer.linear.weight": "final_linear.weight", - "final_layer.linear.bias": "final_linear.bias", + "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight", + "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias", + "final_linear.weight": "final_layer.linear.weight", + "final_linear.bias": "final_layer.linear.bias", # Noise refiner - "noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight", - "noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias", - "noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight", - "noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight", - # Time embedding - "t_embedder.mlp.0.weight": "time_embedder.0.weight", - "t_embedder.mlp.0.bias": "time_embedder.0.bias", - "t_embedder.mlp.2.weight": "time_embedder.2.weight", - "t_embedder.mlp.2.bias": "time_embedder.2.bias", - # Context attention - "context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight", - "context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight", + "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight", + "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias", + "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight", + "single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight", # Normalization - "layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight", - "layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight", + "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight", + "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight", # FFN - "layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight", - "layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight", - "layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight", + "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight", + "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight", + "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight", } def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict: """Convert Diffusers checkpoint to Alpha-VLLM format""" logger.info("Converting Diffusers checkpoint to Alpha-VLLM format") - new_sd = {} + new_sd = sd.copy() # Preserve original keys - for key, value in sd.items(): - new_key = key - for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): - if "()." in pattern: - for block_idx in range(num_double_blocks): - if str(block_idx) in key: - converted = pattern.replace("()", str(block_idx)) - new_key = key.replace(converted, replacement.replace("()", str(block_idx))) - break + for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): + # Handle block-specific patterns + if '().' in diff_key: + for block_idx in range(num_double_blocks): + block_alpha_key = alpha_key.replace('().', f'{block_idx}.') + block_diff_key = diff_key.replace('().', f'{block_idx}.') + + # Search for and convert block-specific keys + for input_key, value in list(sd.items()): + if input_key == block_diff_key: + new_sd[block_alpha_key] = value + else: + # Handle static keys + if diff_key in sd: + print(f"Replacing {diff_key} with {alpha_key}") + new_sd[alpha_key] = sd[diff_key] + else: + print(f"Not found: {diff_key}") - if new_key == key: - logger.debug(f"Unmatched key in conversion: {key}") - new_sd[new_key] = value logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") return new_sd diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 6a4b39b3..c4079884 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -610,21 +610,6 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import BaseOutput -# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - @dataclass class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ @@ -664,49 +649,22 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): self, num_train_timesteps: int = 1000, shift: float = 1.0, - use_dynamic_shifting=False, - base_shift: Optional[float] = 0.5, - max_shift: Optional[float] = 1.15, - base_image_seq_len: Optional[int] = 256, - max_image_seq_len: Optional[int] = 4096, - invert_sigmas: bool = False, - shift_terminal: Optional[float] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, ): - if self.config.use_beta_sigmas and not is_scipy_available(): - raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None - self._shift = shift - self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() - @property - def shift(self): - """ - The value used for shifting. - """ - return self._shift - @property def step_index(self): """ @@ -732,9 +690,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index - def set_shift(self, shift: float): - self._shift = shift - def scale_noise( self, sample: torch.FloatTensor, @@ -754,31 +709,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: A scaled input sample. """ - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) - - if sample.device.type == "mps" and torch.is_floating_point(timestep): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) - timestep = timestep.to(sample.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(sample.device) - timestep = timestep.to(sample.device) - - # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timestep.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timestep.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(sample.shape): - sigma = sigma.unsqueeze(-1) + if self.step_index is None: + self._init_step_index(timestep) + sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -786,37 +720,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: - r""" - Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config - value. - - Reference: - https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 - - Args: - t (`torch.Tensor`): - A tensor of timesteps to be stretched and shifted. - - Returns: - `torch.Tensor`: - A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. - """ - one_minus_z = 1 - t - scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) - stretched_t = 1 - (one_minus_z / scale_factor) - return stretched_t - - def set_timesteps( - self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[float] = None, - ): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -826,49 +730,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - if self.config.use_dynamic_shifting and mu is None: - raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") - - if sigmas is None: - timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps - ) - - sigmas = timesteps / self.config.num_train_timesteps - else: - sigmas = np.array(sigmas).astype(np.float32) - num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) - else: - sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) - - if self.config.shift_terminal: - sigmas = self.stretch_shift_to_terminal(sigmas) - - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - - elif self.config.use_exponential_sigmas: - sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - - elif self.config.use_beta_sigmas: - sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps - - if self.config.invert_sigmas: - sigmas = 1.0 - sigmas - timesteps = sigmas * self.config.num_train_timesteps - sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) - else: - sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.timesteps = timesteps.to(device=device) - self.sigmas = sigmas + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self._step_index = None self._begin_index = None @@ -934,11 +807,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" @@ -954,10 +823,30 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] - prev_sample = sample + (sigma_next - sigma) * model_output + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) @@ -969,86 +858,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential - def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta - def _convert_to_beta( - self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 - ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.array( - [ - sigma_min + (ppf * (sigma_max - sigma_min)) - for ppf in [ - scipy.stats.beta.ppf(timestep, alpha, beta) - for timestep in 1 - np.linspace(0, 1, num_inference_steps) - ] - ] - ) - return sigmas - def __len__(self): return self.config.num_train_timesteps diff --git a/tests/library/test_lumina_models.py b/tests/library/test_lumina_models.py new file mode 100644 index 00000000..ba063688 --- /dev/null +++ b/tests/library/test_lumina_models.py @@ -0,0 +1,295 @@ +import pytest +import torch + +from library.lumina_models import ( + LuminaParams, + to_cuda, + to_cpu, + RopeEmbedder, + TimestepEmbedder, + modulate, + NextDiT, +) + +cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +def test_lumina_params(): + # Test default configuration + default_params = LuminaParams() + assert default_params.patch_size == 2 + assert default_params.in_channels == 4 + assert default_params.axes_dims == [36, 36, 36] + assert default_params.axes_lens == [300, 512, 512] + + # Test 2B config + config_2b = LuminaParams.get_2b_config() + assert config_2b.dim == 2304 + assert config_2b.in_channels == 16 + assert config_2b.n_layers == 26 + assert config_2b.n_heads == 24 + assert config_2b.cap_feat_dim == 2304 + + # Test 7B config + config_7b = LuminaParams.get_7b_config() + assert config_7b.dim == 4096 + assert config_7b.n_layers == 32 + assert config_7b.n_heads == 32 + assert config_7b.axes_dims == [64, 64, 64] + + +@cuda_required +def test_to_cuda_to_cpu(): + # Test tensor conversion + x = torch.tensor([1, 2, 3]) + x_cuda = to_cuda(x) + x_cpu = to_cpu(x_cuda) + assert x.cpu().tolist() == x_cpu.tolist() + + # Test list conversion + list_data = [torch.tensor([1]), torch.tensor([2])] + list_cuda = to_cuda(list_data) + assert all(tensor.device.type == "cuda" for tensor in list_cuda) + + list_cpu = to_cpu(list_cuda) + assert all(not tensor.device.type == "cuda" for tensor in list_cpu) + + # Test dict conversion + dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])} + dict_cuda = to_cuda(dict_data) + assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values()) + + dict_cpu = to_cpu(dict_cuda) + assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values()) + + +def test_timestep_embedder(): + # Test initialization + hidden_size = 256 + freq_emb_size = 128 + embedder = TimestepEmbedder(hidden_size, freq_emb_size) + assert embedder.frequency_embedding_size == freq_emb_size + + # Test timestep embedding + t = torch.tensor([0.5, 1.0, 2.0]) + emb_dim = freq_emb_size + embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim) + + assert embeddings.shape == (3, emb_dim) + assert embeddings.dtype == torch.float32 + + # Ensure embeddings are unique for different input times + assert not torch.allclose(embeddings[0], embeddings[1]) + + # Test forward pass + t_emb = embedder(t) + assert t_emb.shape == (3, hidden_size) + + +def test_rope_embedder_simple(): + rope_embedder = RopeEmbedder() + batch_size, seq_len = 2, 10 + + # Create position_ids with valid ranges for each axis + position_ids = torch.stack( + [ + torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid + torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511 + torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511 + ], + dim=-1, + ) + + freqs_cis = rope_embedder(position_ids) + # RoPE embeddings work in pairs, so output dimension is half of total axes_dims + expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64 + assert freqs_cis.shape == (batch_size, seq_len, expected_dim) + + +def test_modulate(): + # Test modulation with different scales + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + scale = torch.tensor([1.5, 2.0]) + + modulated_x = modulate(x, scale) + + # Check that modulation scales correctly + # The function does x * (1 + scale), so: + # For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0] + expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]]) + # Which equals: [[2.5, 5.0], [9.0, 12.0]] + + assert torch.allclose(modulated_x, expected_x) + + +def test_nextdit_parameter_count_optimized(): + # The constraint is: (dim // n_heads) == sum(axes_dims) + # So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30 + model_small = NextDiT( + patch_size=2, + in_channels=4, # Smaller + dim=120, # 120 // 4 = 30 + n_layers=2, # Much fewer layers + n_heads=4, # Fewer heads + n_kv_heads=2, + axes_dims=[10, 10, 10], # sum = 30 + axes_lens=[10, 32, 32], # Smaller + ) + param_count_small = model_small.parameter_count() + assert param_count_small > 0 + + # For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32 + model_medium = NextDiT( + patch_size=2, + in_channels=4, + dim=192, # 192 // 6 = 32 + n_layers=4, # More layers + n_heads=6, + n_kv_heads=3, + axes_dims=[10, 11, 11], # sum = 32 + axes_lens=[10, 32, 32], + ) + param_count_medium = model_medium.parameter_count() + assert param_count_medium > param_count_small + print(f"Small model: {param_count_small:,} parameters") + print(f"Medium model: {param_count_medium:,} parameters") + + +@torch.no_grad() +def test_precompute_freqs_cis(): + # Test precompute_freqs_cis + dim = [16, 56, 56] + end = [1, 512, 512] + theta = 10000.0 + + freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta) + + # Check number of frequency tensors + assert len(freqs_cis) == len(dim) + + # Check each frequency tensor + for i, (d, e) in enumerate(zip(dim, end)): + assert freqs_cis[i].shape == (e, d // 2) + assert freqs_cis[i].dtype == torch.complex128 + + +@torch.no_grad() +def test_nextdit_patchify_and_embed(): + """Test the patchify_and_embed method which is crucial for training""" + # Create a small NextDiT model for testing + # The constraint is: (dim // n_heads) == sum(axes_dims) + # For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30 + model = NextDiT( + patch_size=2, + in_channels=4, + dim=120, # 120 // 4 = 30 + n_layers=1, # Minimal layers for faster testing + n_refiner_layers=1, # Minimal refiner layers + n_heads=4, + n_kv_heads=2, + axes_dims=[10, 10, 10], # sum = 30 + axes_lens=[10, 32, 32], + cap_feat_dim=120, # Match dim for consistency + ) + + # Prepare test inputs + batch_size = 2 + height, width = 64, 64 # Must be divisible by patch_size (2) + caption_seq_len = 8 + + # Create mock inputs + x = torch.randn(batch_size, 4, height, width) # Image latents + cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features + cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens + # Make second batch have shorter caption + cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch + t = torch.randn(batch_size, 120) # Timestep embeddings + + # Call patchify_and_embed + joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + + # Validate outputs + image_seq_len = (height // 2) * (width // 2) # patch_size = 2 + expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption + max_seq_len = max(expected_seq_lengths) + + # Check joint hidden states shape + assert joint_hidden_states.shape == (batch_size, max_seq_len, 120) + assert joint_hidden_states.dtype == torch.float32 + + # Check attention mask shape and values + assert attention_mask.shape == (batch_size, max_seq_len) + assert attention_mask.dtype == torch.bool + # First batch should have all positions valid up to its sequence length + assert torch.all(attention_mask[0, : expected_seq_lengths[0]]) + assert torch.all(~attention_mask[0, expected_seq_lengths[0] :]) + # Second batch should have all positions valid up to its sequence length + assert torch.all(attention_mask[1, : expected_seq_lengths[1]]) + assert torch.all(~attention_mask[1, expected_seq_lengths[1] :]) + + # Check freqs_cis shape + assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2) + + # Check effective caption lengths + assert l_effective_cap_len == [caption_seq_len, 6] + + # Check sequence lengths + assert seq_lengths == expected_seq_lengths + + # Validate that the joint hidden states contain non-zero values where attention mask is True + for i in range(batch_size): + valid_positions = attention_mask[i] + # Check that valid positions have meaningful data (not all zeros) + valid_data = joint_hidden_states[i][valid_positions] + assert not torch.allclose(valid_data, torch.zeros_like(valid_data)) + + # Check that invalid positions are zeros + if valid_positions.sum() < max_seq_len: + invalid_data = joint_hidden_states[i][~valid_positions] + assert torch.allclose(invalid_data, torch.zeros_like(invalid_data)) + + +@torch.no_grad() +def test_nextdit_patchify_and_embed_edge_cases(): + """Test edge cases for patchify_and_embed""" + # Create minimal model + model = NextDiT( + patch_size=2, + in_channels=4, + dim=60, # 60 // 3 = 20 + n_layers=1, + n_refiner_layers=1, + n_heads=3, + n_kv_heads=1, + axes_dims=[8, 6, 6], # sum = 20 + axes_lens=[10, 16, 16], + cap_feat_dim=60, + ) + + # Test with empty captions (all masked) + batch_size = 1 + height, width = 32, 32 + caption_seq_len = 4 + + x = torch.randn(batch_size, 4, height, width) + cap_feats = torch.randn(batch_size, caption_seq_len, 60) + cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked + t = torch.randn(batch_size, 60) + + joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + + # With all captions masked, effective length should be 0 + assert l_effective_cap_len == [0] + + # Sequence length should just be the image sequence length + image_seq_len = (height // 2) * (width // 2) + assert seq_lengths == [image_seq_len] + + # Joint hidden states should only contain image data + assert joint_hidden_states.shape == (batch_size, image_seq_len, 60) + assert attention_mask.shape == (batch_size, image_seq_len) + assert torch.all(attention_mask[0]) # All image positions should be valid diff --git a/tests/library/test_lumina_train_util.py b/tests/library/test_lumina_train_util.py new file mode 100644 index 00000000..bcf448c8 --- /dev/null +++ b/tests/library/test_lumina_train_util.py @@ -0,0 +1,241 @@ +import pytest +import torch +import math + +from library.lumina_train_util import ( + batchify, + time_shift, + get_lin_function, + get_schedule, + compute_density_for_timestep_sampling, + get_sigmas, + compute_loss_weighting_for_sd3, + get_noisy_model_input_and_timesteps, + apply_model_prediction_type, + retrieve_timesteps, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + + +def test_batchify(): + # Test case with no batch size specified + prompts = [ + {"prompt": "test1"}, + {"prompt": "test2"}, + {"prompt": "test3"} + ] + batchified = list(batchify(prompts)) + assert len(batchified) == 1 + assert len(batchified[0]) == 3 + + # Test case with batch size specified + batchified_sized = list(batchify(prompts, batch_size=2)) + assert len(batchified_sized) == 2 + assert len(batchified_sized[0]) == 2 + assert len(batchified_sized[1]) == 1 + + # Test batching with prompts having same parameters + prompts_with_params = [ + {"prompt": "test1", "width": 512, "height": 512}, + {"prompt": "test2", "width": 512, "height": 512}, + {"prompt": "test3", "width": 1024, "height": 1024} + ] + batchified_params = list(batchify(prompts_with_params)) + assert len(batchified_params) == 2 + + # Test invalid batch size + with pytest.raises(ValueError): + list(batchify(prompts, batch_size=0)) + with pytest.raises(ValueError): + list(batchify(prompts, batch_size=-1)) + + +def test_time_shift(): + # Test standard parameters + t = torch.tensor([0.5]) + mu = 1.0 + sigma = 1.0 + result = time_shift(mu, sigma, t) + assert 0 <= result <= 1 + + # Test with edge cases + t_edges = torch.tensor([0.0, 1.0]) + result_edges = time_shift(1.0, 1.0, t_edges) + + # Check that results are bounded within [0, 1] + assert torch.all(result_edges >= 0) + assert torch.all(result_edges <= 1) + + +def test_get_lin_function(): + # Default parameters + func = get_lin_function() + assert func(256) == 0.5 + assert func(4096) == 1.15 + + # Custom parameters + custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9) + assert custom_func(100) == 0.1 + assert custom_func(1000) == 0.9 + + +def test_get_schedule(): + # Basic schedule + schedule = get_schedule(num_steps=10, image_seq_len=256) + assert len(schedule) == 10 + assert all(0 <= x <= 1 for x in schedule) + + # Test different sequence lengths + short_schedule = get_schedule(num_steps=5, image_seq_len=128) + long_schedule = get_schedule(num_steps=15, image_seq_len=1024) + assert len(short_schedule) == 5 + assert len(long_schedule) == 15 + + # Test with shift disabled + unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False) + assert torch.allclose( + torch.tensor(unshifted_schedule), + torch.linspace(1, 1/10, 10) + ) + + +def test_compute_density_for_timestep_sampling(): + # Test uniform sampling + uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100) + assert len(uniform_samples) == 100 + assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1)) + + # Test logit normal sampling + logit_normal_samples = compute_density_for_timestep_sampling( + "logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0 + ) + assert len(logit_normal_samples) == 100 + assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1)) + + # Test mode sampling + mode_samples = compute_density_for_timestep_sampling( + "mode", batch_size=100, mode_scale=0.5 + ) + assert len(mode_samples) == 100 + assert torch.all((mode_samples >= 0) & (mode_samples <= 1)) + + +def test_get_sigmas(): + # Create a mock noise scheduler + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + device = torch.device('cpu') + + # Test with default parameters + timesteps = torch.tensor([100, 500, 900]) + sigmas = get_sigmas(scheduler, timesteps, device) + + # Check shape and basic properties + assert sigmas.shape[0] == 3 + assert torch.all(sigmas >= 0) + + # Test with different n_dim + sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4) + assert sigmas_4d.ndim == 4 + + # Test with different dtype + sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16) + assert sigmas_float16.dtype == torch.float16 + + +def test_compute_loss_weighting_for_sd3(): + # Prepare some mock sigmas + sigmas = torch.tensor([0.1, 0.5, 1.0]) + + # Test sigma_sqrt weighting + sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas) + assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5) + + # Test cosmap weighting + cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas) + bot = 1 - 2 * sigmas + 2 * sigmas**2 + expected_cosmap = 2 / (math.pi * bot) + assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5) + + # Test default weighting + default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas) + assert torch.all(default_weighting == 1) + + +def test_apply_model_prediction_type(): + # Create mock args and tensors + class MockArgs: + model_prediction_type = "raw" + weighting_scheme = "sigma_sqrt" + + args = MockArgs() + model_pred = torch.tensor([1.0, 2.0, 3.0]) + noisy_model_input = torch.tensor([0.5, 1.0, 1.5]) + sigmas = torch.tensor([0.1, 0.5, 1.0]) + + # Test raw prediction type + raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(raw_pred == model_pred) + assert raw_weighting is None + + # Test additive prediction type + args.model_prediction_type = "additive" + additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(additive_pred == model_pred + noisy_model_input) + + # Test sigma scaled prediction type + args.model_prediction_type = "sigma_scaled" + sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input) + assert sigma_weighting is not None + + +def test_retrieve_timesteps(): + # Create a mock scheduler + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + + # Test with num_inference_steps + timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50) + assert len(timesteps) == 50 + assert n_steps == 50 + + # Test error handling with simultaneous timesteps and sigmas + with pytest.raises(ValueError): + retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3]) + + +def test_get_noisy_model_input_and_timesteps(): + # Create a mock args and setup + class MockArgs: + timestep_sampling = "uniform" + weighting_scheme = "sigma_sqrt" + sigmoid_scale = 1.0 + discrete_flow_shift = 6.0 + + args = MockArgs() + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + device = torch.device('cpu') + + # Prepare mock latents and noise + latents = torch.randn(4, 16, 64, 64) + noise = torch.randn_like(latents) + + # Test uniform sampling + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, scheduler, latents, noise, device, torch.float32 + ) + + # Validate output shapes and types + assert noisy_input.shape == latents.shape + assert timesteps.shape[0] == latents.shape[0] + assert noisy_input.dtype == torch.float32 + assert timesteps.dtype == torch.float32 + + # Test different sampling methods + sampling_methods = ["sigmoid", "shift", "nextdit_shift"] + for method in sampling_methods: + args.timestep_sampling = method + noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps( + args, scheduler, latents, noise, device, torch.float32 + ) + assert noisy_input.shape == latents.shape + assert timesteps.shape[0] == latents.shape[0] diff --git a/tests/library/test_lumina_util.py b/tests/library/test_lumina_util.py new file mode 100644 index 00000000..397bab5a --- /dev/null +++ b/tests/library/test_lumina_util.py @@ -0,0 +1,112 @@ +import torch +from torch.nn.modules import conv + +from library import lumina_util + + +def test_unpack_latents(): + # Create a test tensor + # Shape: [batch, height*width, channels*patch_height*patch_width] + x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels + packed_latent_height = 2 + packed_latent_width = 2 + + # Unpack the latents + unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) + + # Check output shape + # Expected shape: [batch, channels, height*patch_height, width*patch_width] + assert unpacked.shape == (2, 4, 4, 4) + + +def test_pack_latents(): + # Create a test tensor + # Shape: [batch, channels, height*patch_height, width*patch_width] + x = torch.randn(2, 4, 4, 4) + + # Pack the latents + packed = lumina_util.pack_latents(x) + + # Check output shape + # Expected shape: [batch, height*width, channels*patch_height*patch_width] + assert packed.shape == (2, 4, 16) + + +def test_convert_diffusers_sd_to_alpha_vllm(): + num_double_blocks = 2 + # Predefined test cases based on the actual conversion map + test_cases = [ + # Static key conversions with possible list mappings + { + "original_keys": ["time_caption_embed.caption_embedder.0.weight"], + "original_pattern": ["time_caption_embed.caption_embedder.0.weight"], + "expected_converted_keys": ["cap_embedder.0.weight"], + }, + { + "original_keys": ["patch_embedder.proj.weight"], + "original_pattern": ["patch_embedder.proj.weight"], + "expected_converted_keys": ["x_embedder.weight"], + }, + { + "original_keys": ["transformer_blocks.0.norm1.weight"], + "original_pattern": ["transformer_blocks.().norm1.weight"], + "expected_converted_keys": ["layers.0.attention_norm1.weight"], + }, + ] + + + for test_case in test_cases: + for original_key, original_pattern, expected_converted_key in zip( + test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"] + ): + # Create test state dict + test_sd = {original_key: torch.randn(10, 10)} + + # Convert the state dict + converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks) + + # Verify conversion (handle both string and list keys) + # Find the correct converted key + match_found = False + if expected_converted_key in converted_sd: + # Verify tensor preservation + assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), ( + f"Tensor mismatch for {original_key}" + ) + match_found = True + break + + assert match_found, f"Failed to convert {original_key}" + + # Ensure original key is also present + assert original_key in converted_sd + + # Test with block-specific keys + block_specific_cases = [ + { + "original_pattern": "transformer_blocks.().norm1.weight", + "converted_pattern": "layers.().attention_norm1.weight", + } + ] + + for case in block_specific_cases: + for block_idx in range(2): # Test multiple block indices + # Prepare block-specific keys + block_original_key = case["original_pattern"].replace("()", str(block_idx)) + block_converted_key = case["converted_pattern"].replace("()", str(block_idx)) + print(block_original_key, block_converted_key) + + # Create test state dict + test_sd = {block_original_key: torch.randn(10, 10)} + + # Convert the state dict + converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks) + + # Verify conversion + # assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}" + assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), ( + f"Tensor mismatch for block key {block_original_key}" + ) + + # Ensure original key is also present + assert block_original_key in converted_sd diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py new file mode 100644 index 00000000..18e196bf --- /dev/null +++ b/tests/library/test_strategy_lumina.py @@ -0,0 +1,227 @@ +import os +import tempfile +import torch +import numpy as np +from unittest.mock import patch +from transformers import Gemma2Model + +from library.strategy_lumina import ( + LuminaTokenizeStrategy, + LuminaTextEncodingStrategy, + LuminaTextEncoderOutputsCachingStrategy, + LuminaLatentsCachingStrategy, +) + + +class SimpleMockGemma2Model: + """Lightweight mock that avoids initializing the actual Gemma2Model""" + + def __init__(self, hidden_size=2304): + self.device = torch.device("cpu") + self._hidden_size = hidden_size + self._orig_mod = self # For dynamic compilation compatibility + + def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False): + # Create a mock output object with hidden states + batch_size, seq_len = input_ids.shape + hidden_size = self._hidden_size + + class MockOutput: + def __init__(self, hidden_states): + self.hidden_states = hidden_states + + mock_hidden_states = [ + torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device) + for _ in range(3) # Mimic multiple layers of hidden states + ] + + return MockOutput(mock_hidden_states) + + +def test_lumina_tokenize_strategy(): + # Test default initialization + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + assert tokenize_strategy.max_length == 256 + assert tokenize_strategy.tokenizer.padding_side == "right" + + # Test tokenization of a single string + text = "Hello" + tokens, attention_mask = tokenize_strategy.tokenize(text) + + assert tokens.ndim == 2 + assert attention_mask.ndim == 2 + assert tokens.shape == attention_mask.shape + assert tokens.shape[1] == 256 # max_length + + # Test tokenize_with_weights + tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text) + assert len(weights) == 1 + assert torch.all(weights[0] == 1) + + +def test_lumina_text_encoding_strategy(): + # Create strategies + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + encoding_strategy = LuminaTextEncodingStrategy() + + # Create a mock model + mock_model = SimpleMockGemma2Model() + + # Patch the isinstance check to accept our simple mock + original_isinstance = isinstance + with patch("library.strategy_lumina.isinstance") as mock_isinstance: + + def custom_isinstance(obj, class_or_tuple): + if obj is mock_model and class_or_tuple is Gemma2Model: + return True + if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model: + return True + return original_isinstance(obj, class_or_tuple) + + mock_isinstance.side_effect = custom_isinstance + + # Prepare sample text + text = "Test encoding strategy" + tokens, attention_mask = tokenize_strategy.tokenize(text) + + # Perform encoding + hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens( + tokenize_strategy, [mock_model], (tokens, attention_mask) + ) + + # Validate outputs + assert original_isinstance(hidden_states, torch.Tensor) + assert original_isinstance(input_ids, torch.Tensor) + assert original_isinstance(attention_masks, torch.Tensor) + + # Check the shape of the second-to-last hidden state + assert hidden_states.ndim == 3 + + # Test weighted encoding (which falls back to standard encoding for Lumina) + weights = [torch.ones_like(tokens)] + hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [mock_model], (tokens, attention_mask), weights + ) + + # For the mock, we can't guarantee identical outputs since each call returns random tensors + # Instead, check that the outputs have the same shape and are tensors + assert hidden_states_w.shape == hidden_states.shape + assert original_isinstance(hidden_states_w, torch.Tensor) + assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same + assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same + + +def test_lumina_text_encoder_outputs_caching_strategy(): + # Create a temporary directory for caching + with tempfile.TemporaryDirectory() as tmpdir: + # Create a cache file path + cache_file = os.path.join(tmpdir, "test_outputs.npz") + + # Create the caching strategy + caching_strategy = LuminaTextEncoderOutputsCachingStrategy( + cache_to_disk=True, + batch_size=1, + skip_disk_cache_validity_check=False, + ) + + # Create a mock class for ImageInfo + class MockImageInfo: + def __init__(self, caption, system_prompt, cache_path): + self.caption = caption + self.system_prompt = system_prompt + self.text_encoder_outputs_npz = cache_path + + # Create a sample input info + image_info = MockImageInfo("Test caption", "", cache_file) + + # Simulate a batch + batch = [image_info] + + # Create mock strategies and model + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + encoding_strategy = LuminaTextEncodingStrategy() + mock_model = SimpleMockGemma2Model() + + # Patch the isinstance check to accept our simple mock + original_isinstance = isinstance + with patch("library.strategy_lumina.isinstance") as mock_isinstance: + + def custom_isinstance(obj, class_or_tuple): + if obj is mock_model and class_or_tuple is Gemma2Model: + return True + if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model: + return True + return original_isinstance(obj, class_or_tuple) + + mock_isinstance.side_effect = custom_isinstance + + # Call cache_batch_outputs + caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch) + + # Verify the npz file was created + assert os.path.exists(cache_file), f"Cache file not created at {cache_file}" + + # Verify the is_disk_cached_outputs_expected method + assert caching_strategy.is_disk_cached_outputs_expected(cache_file) + + # Test loading from npz + loaded_data = caching_strategy.load_outputs_npz(cache_file) + assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask + + +def test_lumina_latents_caching_strategy(): + # Create a temporary directory for caching + with tempfile.TemporaryDirectory() as tmpdir: + # Prepare a mock absolute path + abs_path = os.path.join(tmpdir, "test_image.png") + + # Use smaller image size for faster testing + image_size = (64, 64) + + # Create a smaller dummy image for testing + test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + + # Create the caching strategy + caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False) + + # Create a simple mock VAE + class MockVAE: + def __init__(self): + self.device = torch.device("cpu") + self.dtype = torch.float32 + + def encode(self, x): + # Return smaller encoded tensor for faster processing + encoded = torch.randn(1, 4, 8, 8, device=x.device) + return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded}) + + # Prepare a mock batch + class MockImageInfo: + def __init__(self, path, image): + self.absolute_path = path + self.image = image + self.image_path = path + self.bucket_reso = image_size + self.resized_size = image_size + self.resize_interpolation = "lanczos" + # Specify full path to the latents npz file + self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz") + + batch = [MockImageInfo(abs_path, test_image)] + + # Call cache_batch_latents + mock_vae = MockVAE() + caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False) + + # Generate the expected npz path + npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size) + + # Verify the file was created + assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}" + + # Verify is_disk_cached_latents_expected + assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False) + + # Test loading from disk + loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size) + assert len(loaded_data) == 5 # Check for 5 expected elements diff --git a/tests/test_lumina_train_network.py b/tests/test_lumina_train_network.py new file mode 100644 index 00000000..353a742f --- /dev/null +++ b/tests/test_lumina_train_network.py @@ -0,0 +1,173 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +import argparse + +from library import lumina_models, lumina_util +from lumina_train_network import LuminaNetworkTrainer + + +@pytest.fixture +def lumina_trainer(): + return LuminaNetworkTrainer() + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.pretrained_model_name_or_path = "test_path" + args.disable_mmap_load_safetensors = False + args.use_flash_attn = False + args.use_sage_attn = False + args.fp8_base = False + args.blocks_to_swap = None + args.gemma2 = "test_gemma2_path" + args.ae = "test_ae_path" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = False + args.network_train_unet_only = False + return args + + +@pytest.fixture +def mock_accelerator(): + accelerator = MagicMock() + accelerator.device = torch.device("cpu") + accelerator.prepare.side_effect = lambda x, **kwargs: x + accelerator.unwrap_model.side_effect = lambda x: x + return accelerator + + +def test_assert_extra_args(lumina_trainer, mock_args): + train_dataset_group = MagicMock() + train_dataset_group.verify_bucket_reso_steps = MagicMock() + val_dataset_group = MagicMock() + val_dataset_group.verify_bucket_reso_steps = MagicMock() + + # Test with default settings + lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group) + + # Verify verify_bucket_reso_steps was called for both groups + assert train_dataset_group.verify_bucket_reso_steps.call_count > 0 + assert val_dataset_group.verify_bucket_reso_steps.call_count > 0 + + # Check text encoder output caching + assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only) + assert mock_args.cache_text_encoder_outputs is True + + +def test_load_target_model(lumina_trainer, mock_args, mock_accelerator): + # Patch lumina_util methods + with ( + patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model, + patch("library.lumina_util.load_gemma2") as mock_load_gemma2, + patch("library.lumina_util.load_ae") as mock_load_ae + ): + # Create mock models + mock_model = MagicMock(spec=lumina_models.NextDiT) + mock_model.dtype = torch.float32 + mock_gemma2 = MagicMock() + mock_ae = MagicMock() + + mock_load_lumina_model.return_value = mock_model + mock_load_gemma2.return_value = mock_gemma2 + mock_load_ae.return_value = mock_ae + + # Test load_target_model + version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator) + + # Verify calls and return values + assert version == lumina_util.MODEL_VERSION_LUMINA_V2 + assert gemma2_list == [mock_gemma2] + assert ae == mock_ae + assert model == mock_model + + # Verify load calls + mock_load_lumina_model.assert_called_once() + mock_load_gemma2.assert_called_once() + mock_load_ae.assert_called_once() + + +def test_get_strategies(lumina_trainer, mock_args): + # Test tokenize strategy + tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args) + assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy" + + # Test latents caching strategy + latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args) + assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy" + + # Test text encoding strategy + text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args) + assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy" + + +def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args): + # Call assert_extra_args to set train_gemma2 + train_dataset_group = MagicMock() + train_dataset_group.verify_bucket_reso_steps = MagicMock() + val_dataset_group = MagicMock() + val_dataset_group.verify_bucket_reso_steps = MagicMock() + lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group) + + # With text encoder caching enabled + mock_args.skip_cache_check = False + mock_args.text_encoder_batch_size = 16 + strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) + + assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy" + assert strategy.cache_to_disk is False # based on mock_args + + # With text encoder caching disabled + mock_args.cache_text_encoder_outputs = False + strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) + assert strategy is None + + +def test_noise_scheduler(lumina_trainer, mock_args): + device = torch.device("cpu") + noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device) + + assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler" + assert noise_scheduler.num_train_timesteps == 1000 + assert hasattr(lumina_trainer, "noise_scheduler_copy") + + +def test_sai_model_spec(lumina_trainer, mock_args): + with patch("library.train_util.get_sai_model_spec") as mock_get_spec: + mock_get_spec.return_value = "test_spec" + spec = lumina_trainer.get_sai_model_spec(mock_args) + assert spec == "test_spec" + mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2") + + +def test_update_metadata(lumina_trainer, mock_args): + metadata = {} + lumina_trainer.update_metadata(metadata, mock_args) + + assert "ss_weighting_scheme" in metadata + assert "ss_logit_mean" in metadata + assert "ss_logit_std" in metadata + assert "ss_mode_scale" in metadata + assert "ss_timestep_sampling" in metadata + assert "ss_sigmoid_scale" in metadata + assert "ss_model_prediction_type" in metadata + assert "ss_discrete_flow_shift" in metadata + + +def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args): + # Test with text encoder output caching, but not training text encoder + mock_args.cache_text_encoder_outputs = True + with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=False): + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is True + + # Test with text encoder output caching and training text encoder + with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=True): + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is False + + # Test with no text encoder output caching + mock_args.cache_text_encoder_outputs = False + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is False \ No newline at end of file From 0e929f97b9dfc488a454d62a3e27696c167a3936 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 16 Jun 2025 16:50:18 -0400 Subject: [PATCH 47/73] Revert system_prompt for dataset config --- library/train_util.py | 74 +++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 48 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 68019e21..1d80bcd8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -192,7 +192,7 @@ class ImageInfo: self.latents_flipped: Optional[torch.Tensor] = None self.latents_npz: Optional[str] = None # set in cache_latents self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = ( + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( None # crop left top right bottom in original pixel size, not latents size ) self.cond_img_path: Optional[str] = None @@ -209,8 +209,6 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime self.resize_interpolation: Optional[str] = None - self.system_prompt: Optional[str] = None - class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -434,7 +432,6 @@ class BaseSubset: custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir @@ -466,7 +463,6 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split - self.system_prompt = system_prompt self.resize_interpolation = resize_interpolation @@ -500,7 +496,6 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -529,15 +524,14 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension - # if self.caption_extension and not self.caption_extension.startswith("."): - # self.caption_extension = "." + self.caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension self.cache_info = cache_info def __eq__(self, other) -> bool: @@ -573,7 +567,6 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -602,7 +595,6 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) @@ -642,7 +634,6 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -671,7 +662,6 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) @@ -1713,10 +1703,8 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: - system_prompt_special_token = "" - system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else "" caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension # if self.XTI_layers: # caption_layer = [] # for layer in self.XTI_layers: @@ -1886,8 +1874,7 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - system_prompt: Optional[str] = None, - resize_interpolation: Optional[str] = None, + resize_interpolation: Optional[str], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -1900,7 +1887,6 @@ class DreamBoothDataset(BaseDataset): self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split - self.system_prompt = system_prompt self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1917,33 +1903,30 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool): + def read_caption(img_path, caption_extension, enable_wildcard): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name tokens = base_name.split("_") if len(tokens) >= 5: base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)] + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] caption = None - for base, cap_extension in cap_paths: - # check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt) - for cap_path in [base + cap_extension, base + "." + cap_extension]: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding="utf-8") as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - if enable_wildcard: - caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 - else: - caption = lines[0].strip() - break - break + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() + break return caption def load_dreambooth_dir(subset: DreamBoothSubset): @@ -2090,7 +2073,6 @@ class DreamBoothDataset(BaseDataset): num_train_images = 0 num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] - for subset in subsets: num_repeats = subset.num_repeats if self.is_training_dataset else 1 if num_repeats < 1: @@ -2117,10 +2099,8 @@ class DreamBoothDataset(BaseDataset): else: num_train_images += num_repeats * len(img_paths) - system_prompt_special_token = "" - system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else "" for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation if size is not None: info.image_size = size @@ -2177,8 +2157,7 @@ class FineTuningDataset(BaseDataset): debug_dataset: bool, validation_seed: int, validation_split: float, - system_prompt: Optional[str] = None, - resize_interpolation: Optional[str] = None, + resize_interpolation: Optional[str], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2406,8 +2385,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], - system_prompt: Optional[str] = None, + validation_seed: Optional[int], resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2461,7 +2439,6 @@ class ControlNetDataset(BaseDataset): debug_dataset, validation_split, validation_seed, - system_prompt, resize_interpolation, ) @@ -3005,7 +2982,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size @@ -6241,6 +6218,7 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["renorm_cfg"] = float(m.group(1)) continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) From 935e0037dc7d520f87e2d05dd0a306bfe26c60bc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Jun 2025 21:33:09 +0900 Subject: [PATCH 48/73] feat: update lumina system prompt handling --- .gitignore | 1 + library/config_util.py | 6 ------ library/strategy_lumina.py | 3 +-- lumina_train.py | 4 +++- lumina_train_network.py | 9 ++++----- tests/library/test_strategy_lumina.py | 5 ++--- 6 files changed, 11 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index e492b1ad..4fcf07f6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ venv build .vscode wandb +MagicMock \ No newline at end of file diff --git a/library/config_util.py b/library/config_util.py index ac726e4f..53727f25 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,7 +75,6 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 - system_prompt: Optional[str] = None resize_interpolation: Optional[str] = None @@ -108,7 +107,6 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 - system_prompt: Optional[str] = None resize_interpolation: Optional[str] = None @dataclass @@ -199,7 +197,6 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, - "system_prompt": str, "resize_interpolation": str, } # DO means DropOut @@ -246,7 +243,6 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, - "system_prompt": str, "resize_interpolation": str, } @@ -534,7 +530,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu resolution: {(dataset.width, dataset.height)} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} - system_prompt: {dataset.system_prompt} """) if dataset.enable_bucket: @@ -569,7 +564,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu alpha_mask: {subset.alpha_mask} resize_interpolation: {subset.resize_interpolation} custom_attributes: {subset.custom_attributes} - system_prompt: {subset.system_prompt} """), " ") if is_dreambooth: diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index d9e93f53..3d86dbef 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -218,8 +218,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) - system_prompt_special_token = "" - captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch] + captions = [info.caption for info in batch] if self.is_weighted: tokens, attention_masks, weights_list = ( diff --git a/lumina_train.py b/lumina_train.py index 330d0093..4b733c9e 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -266,12 +266,14 @@ def train(args): strategy_base.TextEncodingStrategy.get_strategy() ) + system_prompt_special_token = "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: for p in [ - prompt_dict.get("prompt", ""), + system_prompt + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ]: if p not in sample_prompts_te_outputs: diff --git a/lumina_train_network.py b/lumina_train_network.py index e1b45ac7..037ddac6 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -58,7 +58,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, use_flash_attn=args.use_flash_attn, - use_sage_attn=args.use_sage_attn + use_sage_attn=args.use_sage_attn, ) if args.fp8_base: @@ -75,7 +75,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): model.to(torch.float8_e4m3fn) if args.blocks_to_swap: - logger.info(f'Lumina 2: Enabling block swap: {args.blocks_to_swap}') + logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}") model.enable_block_swap(args.blocks_to_swap, accelerator.device) self.is_swapping_blocks = True @@ -157,13 +157,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) system_prompt_special_token = "" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in sample_prompts: prompts = [ - prompt_dict.get("prompt", ""), + system_prompt + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ] for i, prompt in enumerate(prompts): @@ -371,7 +371,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): accelerator.unwrap_model(unet).prepare_block_swap_before_forward() - def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser) diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py index 18e196bf..aca16347 100644 --- a/tests/library/test_strategy_lumina.py +++ b/tests/library/test_strategy_lumina.py @@ -126,13 +126,12 @@ def test_lumina_text_encoder_outputs_caching_strategy(): # Create a mock class for ImageInfo class MockImageInfo: - def __init__(self, caption, system_prompt, cache_path): + def __init__(self, caption, cache_path): self.caption = caption - self.system_prompt = system_prompt self.text_encoder_outputs_npz = cache_path # Create a sample input info - image_info = MockImageInfo("Test caption", "", cache_file) + image_info = MockImageInfo("Test caption", cache_file) # Simulate a batch batch = [image_info] From 884c1f37c4c16fa83ed14f46f6e209770fbed4d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Jun 2025 21:58:43 +0900 Subject: [PATCH 49/73] fix: update to work with cache text encoder outputs (without disk) --- library/strategy_lumina.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 3d86dbef..392d6594 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -264,8 +264,8 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) else: info.text_encoder_outputs = [ hidden_state_i, - attention_mask_i, input_ids_i, + attention_mask_i, ] From 5034c6f813a39c1db9c2b0a5f8140f6364ca984d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Jun 2025 22:00:58 +0900 Subject: [PATCH 50/73] feat: add workaround for 'gated repo' error on github actions --- tests/library/test_strategy_lumina.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py index aca16347..9bb0edf7 100644 --- a/tests/library/test_strategy_lumina.py +++ b/tests/library/test_strategy_lumina.py @@ -40,7 +40,12 @@ class SimpleMockGemma2Model: def test_lumina_tokenize_strategy(): # Test default initialization - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + try: + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + return assert tokenize_strategy.max_length == 256 assert tokenize_strategy.tokenizer.padding_side == "right" @@ -61,7 +66,12 @@ def test_lumina_tokenize_strategy(): def test_lumina_text_encoding_strategy(): # Create strategies - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + try: + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + return encoding_strategy = LuminaTextEncodingStrategy() # Create a mock model @@ -137,7 +147,12 @@ def test_lumina_text_encoder_outputs_caching_strategy(): batch = [image_info] # Create mock strategies and model - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + try: + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + return encoding_strategy = LuminaTextEncodingStrategy() mock_model = SimpleMockGemma2Model() From 078ee28a949b65d16ade97824d8273bd8bbd6598 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Jun 2025 22:06:19 +0900 Subject: [PATCH 51/73] feat: add more workaround for 'gated repo' error on github actions --- tests/test_lumina_train_network.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/test_lumina_train_network.py b/tests/test_lumina_train_network.py index 353a742f..2b8fe21d 100644 --- a/tests/test_lumina_train_network.py +++ b/tests/test_lumina_train_network.py @@ -61,7 +61,7 @@ def test_load_target_model(lumina_trainer, mock_args, mock_accelerator): with ( patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model, patch("library.lumina_util.load_gemma2") as mock_load_gemma2, - patch("library.lumina_util.load_ae") as mock_load_ae + patch("library.lumina_util.load_ae") as mock_load_ae, ): # Create mock models mock_model = MagicMock(spec=lumina_models.NextDiT) @@ -90,8 +90,12 @@ def test_load_target_model(lumina_trainer, mock_args, mock_accelerator): def test_get_strategies(lumina_trainer, mock_args): # Test tokenize strategy - tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args) - assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy" + try: + tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args) + assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy" + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") # Test latents caching strategy latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args) @@ -114,10 +118,10 @@ def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args): mock_args.skip_cache_check = False mock_args.text_encoder_batch_size = 16 strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) - + assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy" assert strategy.cache_to_disk is False # based on mock_args - + # With text encoder caching disabled mock_args.cache_text_encoder_outputs = False strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) @@ -158,16 +162,16 @@ def test_update_metadata(lumina_trainer, mock_args): def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args): # Test with text encoder output caching, but not training text encoder mock_args.cache_text_encoder_outputs = True - with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=False): + with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False): result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) assert result is True # Test with text encoder output caching and training text encoder - with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=True): + with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True): result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) assert result is False # Test with no text encoder output caching mock_args.cache_text_encoder_outputs = False result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) - assert result is False \ No newline at end of file + assert result is False From 6731d8a57fb9a31c37dfaf926c5d70af0dc69b24 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Jun 2025 22:21:48 +0900 Subject: [PATCH 52/73] fix: update system prompt handling --- library/strategy_lumina.py | 16 ++++++++++++++-- lumina_train.py | 14 ++++++-------- lumina_train_network.py | 11 +++-------- tests/library/test_strategy_lumina.py | 6 +++--- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 392d6594..964d9f7a 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -25,20 +25,26 @@ GEMMA_ID = "google/gemma-2-2b" class LuminaTokenizeStrategy(TokenizeStrategy): def __init__( - self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None + self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None ) -> None: self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained( GEMMA_ID, cache_dir=tokenizer_cache_dir ) self.tokenizer.padding_side = "right" + if system_prompt is None: + system_prompt = "" + system_prompt_special_token = "" + system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else "" + self.system_prompt = system_prompt + if max_length is None: self.max_length = 256 else: self.max_length = max_length def tokenize( - self, text: Union[str, List[str]] + self, text: Union[str, List[str]], is_negative: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -49,6 +55,12 @@ class LuminaTokenizeStrategy(TokenizeStrategy): token input ids, attention_masks """ text = [text] if isinstance(text, str) else text + + # In training, we always add system prompt (is_negative=False) + if not is_negative: + # Add system prompt to the beginning of each text + text = [self.system_prompt + t for t in text] + encodings = self.tokenizer( text, max_length=self.max_length, diff --git a/lumina_train.py b/lumina_train.py index 4b733c9e..0a91f4a0 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -166,7 +166,7 @@ def train(args): ) ) strategy_base.TokenizeStrategy.set_strategy( - strategy_lumina.LuminaTokenizeStrategy() + strategy_lumina.LuminaTokenizeStrategy(args.system_prompt) ) train_dataset_group.set_current_strategies() @@ -221,7 +221,7 @@ def train(args): gemma2_max_token_length = args.gemma2_max_token_length lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy( - gemma2_max_token_length + args.system_prompt, gemma2_max_token_length ) strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy) @@ -266,19 +266,17 @@ def train(args): strategy_base.TextEncodingStrategy.get_strategy() ) - system_prompt_special_token = "" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: - for p in [ - system_prompt + prompt_dict.get("prompt", ""), + for i, p in enumerate([ + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), - ]: + ]): if p not in sample_prompts_te_outputs: logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = lumina_tokenize_strategy.tokenize(p) + tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt sample_prompts_te_outputs[p] = ( text_encoding_strategy.encode_tokens( lumina_tokenize_strategy, diff --git a/lumina_train_network.py b/lumina_train_network.py index 037ddac6..b08e3143 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -86,7 +86,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model def get_tokenize_strategy(self, args): - return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir) + return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): return [tokenize_strategy.tokenizer] @@ -156,25 +156,20 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - system_prompt_special_token = "" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in sample_prompts: prompts = [ - system_prompt + prompt_dict.get("prompt", ""), + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ] for i, prompt in enumerate(prompts): - # Add system prompt only to positive prompt - if i == 0: - prompt = system_prompt + prompt if prompt in sample_prompts_te_outputs: continue logger.info(f"cache Text Encoder outputs for prompt: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) + tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens( tokenize_strategy, text_encoders, diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py index 9bb0edf7..d77d2738 100644 --- a/tests/library/test_strategy_lumina.py +++ b/tests/library/test_strategy_lumina.py @@ -41,7 +41,7 @@ class SimpleMockGemma2Model: def test_lumina_tokenize_strategy(): # Test default initialization try: - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) except OSError as e: # If the tokenizer is not found (due to gated repo), we can skip the test print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") @@ -67,7 +67,7 @@ def test_lumina_tokenize_strategy(): def test_lumina_text_encoding_strategy(): # Create strategies try: - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) except OSError as e: # If the tokenizer is not found (due to gated repo), we can skip the test print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") @@ -148,7 +148,7 @@ def test_lumina_text_encoder_outputs_caching_strategy(): # Create mock strategies and model try: - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) except OSError as e: # If the tokenizer is not found (due to gated repo), we can skip the test print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") From 05f392fa27371291b26c0ca5b751a3b829cd52d2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 3 Jul 2025 21:47:15 +0900 Subject: [PATCH 53/73] feat: add minimum inference code for Lumina with image generation capabilities --- lumina_minimal_inference.py | 295 ++++++++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 lumina_minimal_inference.py diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py new file mode 100644 index 00000000..ff7c21df --- /dev/null +++ b/lumina_minimal_inference.py @@ -0,0 +1,295 @@ +# Minimum Inference Code for Lumina +# Based on flux_minimal_inference.py + +import logging +import argparse +import math +import os +import random +import time +from typing import Optional + +import einops +import numpy as np +import torch +from accelerate import Accelerator +from PIL import Image +from safetensors.torch import load_file +from tqdm import tqdm +from transformers import Gemma2Model +from library.flux_models import AutoEncoder + +from library import ( + device_utils, + lumina_models, + lumina_train_util, + lumina_util, + sd3_train_utils, + strategy_lumina, +) +from library.device_utils import get_preferred_device, init_ipex +from library.utils import setup_logging, str_to_dtype + +init_ipex() +setup_logging() +logger = logging.getLogger(__name__) + + +def generate_image( + model: lumina_models.NextDiT, + gemma2: Gemma2Model, + ae: AutoEncoder, + prompt: str, + system_prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: int, + guidance_scale: float, + negative_prompt: Optional[str], + args, + cfg_trunc_ratio: float = 0.25, + renorm_cfg: float = 1.0, +): + # + # 0. Prepare arguments + # + device = get_preferred_device() + if args.device: + device = torch.device(args.device) + + dtype = str_to_dtype(args.dtype) + ae_dtype = str_to_dtype(args.ae_dtype) + gemma2_dtype = str_to_dtype(args.gemma2_dtype) + + # + # 1. Prepare models + # + # model.to(device, dtype=dtype) + model.to(dtype) + model.eval() + + gemma2.to(device, dtype=gemma2_dtype) + gemma2.eval() + + ae.to(ae_dtype) + ae.eval() + + # + # 2. Encode prompts + # + logger.info("Encoding prompts...") + + tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length) + encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy() + + tokens_and_masks = tokenize_strategy.tokenize(prompt) + with torch.no_grad(): + gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True) + with torch.no_grad(): + neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks) + + # Unpack Gemma2 outputs + prompt_hidden_states, _, prompt_attention_mask = gemma2_conds + uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds + + if args.offload: + print("Offloading models to CPU to save VRAM...") + gemma2.to("cpu") + device_utils.clean_memory() + + model.to(device) + + # + # 3. Prepare latents + # + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + torch.manual_seed(seed) + + latent_height = image_height // 8 + latent_width = image_width // 8 + latent_channels = 16 + + latents = torch.randn( + (1, latent_channels, latent_height, latent_width), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # + # 4. Denoise + # + logger.info("Denoising...") + scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + scheduler.set_timesteps(steps, device=device) + timesteps = scheduler.timesteps + + # # compare with lumina_train_util.retrieve_timesteps + # lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps) + # print(f"Using timesteps: {timesteps}") + # print(f"vs Lumina timesteps: {lumina_timestep}") # should be the same + + with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad(): + latents = lumina_train_util.denoise( + scheduler, + model, + latents.to(device), + prompt_hidden_states.to(device), + prompt_attention_mask.to(device), + uncond_hidden_states.to(device), + uncond_attention_mask.to(device), + timesteps, + guidance_scale, + cfg_trunc_ratio, + renorm_cfg, + ) + + if args.offload: + model.to("cpu") + device_utils.clean_memory() + ae.to(device) + + # + # 5. Decode latents + # + logger.info("Decoding image...") + latents = latents / ae.scale_factor + ae.shift_factor + with torch.no_grad(): + image = ae.decode(latents.to(ae_dtype)) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + image = (image * 255).round().astype("uint8") + + # + # 6. Save image + # + pil_image = Image.fromarray(image[0]) + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + seed_suffix = f"_{seed}" + output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png") + pil_image.save(output_path) + logger.info(f"Image saved to {output_path}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Lumina DiT model path / Lumina DiTモデルのパス", + ) + parser.add_argument( + "--gemma2_path", + type=str, + default=None, + required=True, + help="Gemma2 model path / Gemma2モデルのパス", + ) + parser.add_argument( + "--ae_path", + type=str, + default=None, + required=True, + help="Autoencoder model path / Autoencoderモデルのパス", + ) + parser.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation") + parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty") + parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images") + parser.add_argument("--seed", type=int, default=None, help="Random seed") + parser.add_argument("--steps", type=int, default=30, help="Number of inference steps") + parser.add_argument("--guidance_scale", type=float, default=4.0, help="Guidance scale for classifier-free guidance") + parser.add_argument("--image_width", type=int, default=1024, help="Image width") + parser.add_argument("--image_height", type=int, default=1024, help="Image height") + parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)") + parser.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)") + parser.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)") + parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')") + parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM") + parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model") + parser.add_argument( + "--gemma2_max_token_length", + type=int, + default=256, + help="Max token length for Gemma2 tokenizer", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=1.0, + help="Shift value for FlowMatchEulerDiscreteScheduler", + ) + parser.add_argument( + "--cfg_trunc_ratio", + type=float, + default=0.25, + help="TBD", + ) + parser.add_argument( + "--renorm_cfg", + type=float, + default=1.0, + help="TBD", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + help="Use flash attention for Lumina model", + ) + parser.add_argument( + "--use_sage_attn", + action="store_true", + help="Use sage attention for Lumina model", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + + logger.info("Loading models...") + device = get_preferred_device() + if args.device: + device = torch.device(args.device) + + # Load Lumina DiT model + model = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + dtype=None, # Load in fp32 and then convert + device="cpu", + use_flash_attn=args.use_flash_attn, + use_sage_attn=args.use_sage_attn, + ) + + # Load Gemma2 + gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu") + + # Load Autoencoder + ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu") + + generate_image( + model, + gemma2, + ae, + args.prompt, + args.system_prompt, + args.seed, + args.image_width, + args.image_height, + args.steps, + args.guidance_scale, + args.negative_prompt, + args, + args.cfg_trunc_ratio, + args.renorm_cfg, + ) + + logger.info("Done.") From a87e9997861c58df7148705be12dae17114615de Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 7 Jul 2025 17:12:07 -0400 Subject: [PATCH 54/73] Change to 3 --- networks/lora_lumina.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 15c35f44..e4149b4a 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -344,7 +344,7 @@ def create_network( if embedder_dims.startswith("[") and embedder_dims.endswith("]"): embedder_dims = embedder_dims[1:-1] embedder_dims = [int(d) for d in embedder_dims.split(",")] - assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder)" + assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)" # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) From b4d11522939ce65aef46d835c00969a25bb485c5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 9 Jul 2025 21:55:36 +0900 Subject: [PATCH 55/73] fix: sample generation with system prompt, without TE output caching --- library/lumina_train_util.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 14a79bb2..45f22bc4 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -249,7 +249,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, nextdit: lumina_models.NextDiT, - gemma2_model: Gemma2Model, + gemma2_model: list[Gemma2Model], vae: AutoEncoder, save_dir: str, prompt_dicts: list[Dict[str, str]], @@ -266,7 +266,7 @@ def sample_image_inference( accelerator (Accelerator): Accelerator object args (argparse.Namespace): Arguments object nextdit (lumina_models.NextDiT): NextDiT model - gemma2_model (Gemma2Model): Gemma2 model + gemma2_model (list[Gemma2Model]): Gemma2 model vae (AutoEncoder): VAE model save_dir (str): Directory to save images prompt_dict (Dict[str, str]): Prompt dictionary @@ -330,12 +330,8 @@ def sample_image_inference( logger.info(f"renorm: {renorm_cfg}") # logger.info(f"sample_sampler: {sampler_name}") - system_prompt_special_token = "" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" - # Apply system prompt to prompts - prompt = system_prompt + prompt - negative_prompt = negative_prompt + # No need to add system prompt here, as it has been handled in the tokenize_strategy # Get sample prompts from cache if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: @@ -355,12 +351,12 @@ def sample_image_inference( if gemma2_model is not None: tokens_and_masks = tokenize_strategy.tokenize(prompt) gemma2_conds = encoding_strategy.encode_tokens( - tokenize_strategy, [gemma2_model], tokens_and_masks + tokenize_strategy, gemma2_model, tokens_and_masks ) - tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True) neg_gemma2_conds = encoding_strategy.encode_tokens( - tokenize_strategy, [gemma2_model], tokens_and_masks + tokenize_strategy, gemma2_model, tokens_and_masks ) # Unpack Gemma2 outputs From 7fb0d30feba5f1112ad28099ac79b6109e98ec57 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 9 Jul 2025 23:28:55 +0900 Subject: [PATCH 56/73] feat: add LoRA support for lumina minimal inference --- lumina_minimal_inference.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index ff7c21df..ba305f6f 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -27,6 +27,7 @@ from library import ( sd3_train_utils, strategy_lumina, ) +import networks.lora_lumina as lora_lumina from library.device_utils import get_preferred_device, init_ipex from library.utils import setup_logging, str_to_dtype @@ -248,6 +249,14 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="Use sage attention for Lumina model", ) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") return parser @@ -275,6 +284,30 @@ if __name__ == "__main__": # Load Autoencoder ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu") + # LoRA + lora_models = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + lora_model, _ = lora_lumina.create_network_from_weights( + multiplier, None, ae, [gemma2], model, weights_sd, True + ) + + if args.merge_lora_weights: + lora_model.merge_to([gemma2], model, weights_sd) + else: + lora_model.apply_to([gemma2], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + + lora_models.append(lora_model) + generate_image( model, gemma2, From 3f9eab49467ba2d224d48464aac11cb07b85dbb1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 9 Jul 2025 23:33:50 +0900 Subject: [PATCH 57/73] fix: update default values in lumina minimal inference as same as sample generation --- lumina_minimal_inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index ba305f6f..4f915179 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -205,8 +205,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty") parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images") parser.add_argument("--seed", type=int, default=None, help="Random seed") - parser.add_argument("--steps", type=int, default=30, help="Number of inference steps") - parser.add_argument("--guidance_scale", type=float, default=4.0, help="Guidance scale for classifier-free guidance") + parser.add_argument("--steps", type=int, default=36, help="Number of inference steps") + parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance") parser.add_argument("--image_width", type=int, default=1024, help="Image width") parser.add_argument("--image_height", type=int, default=1024, help="Image height") parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)") @@ -224,7 +224,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--discrete_flow_shift", type=float, - default=1.0, + default=6.0, help="Shift value for FlowMatchEulerDiscreteScheduler", ) parser.add_argument( From d0b335d8cf543da68963103cbd7ae8d630d1eb3a Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 10 Jul 2025 20:15:45 +0900 Subject: [PATCH 58/73] feat: add LoRA training guide for Lumina Image 2.0 (WIP) --- docs/lumina_train_network.md | 311 +++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 docs/lumina_train_network.md diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md new file mode 100644 index 00000000..1c3794ab --- /dev/null +++ b/docs/lumina_train_network.md @@ -0,0 +1,311 @@ +Status: reviewed + +# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド + +This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository. + +## 1. Introduction / はじめに + +`lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts a Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE). + +This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md). + +**Prerequisites:** + +* The `sd-scripts` repository has been cloned and the Python environment is ready. +* A training dataset has been prepared. See the [Dataset Configuration Guide](link/to/dataset/config/doc). +* Lumina Image 2.0 model files for training are available. + +
+日本語 +ステータス:内容を一通り確認した + +`lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。 + +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。 + +**前提条件:** + +* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください) +* 学習対象のLumina Image 2.0モデルファイルが準備できていること。 +
+ +## 2. Differences from `train_network.py` / `train_network.py` との違い + +`lumina_train_network.py` is based on `train_network.py` but modified for Lumina Image 2.0. Main differences are: + +* **Target models:** Lumina Image 2.0 models. +* **Model structure:** Uses Next-DiT (Transformer based) instead of U-Net and employs a single text encoder (Gemma2). The AutoEncoder (AE) is not compatible with SDXL/SD3/FLUX. +* **Arguments:** Options exist to specify the Lumina Image 2.0 model, Gemma2 text encoder and AE. With a single `.safetensors` file, these components are typically provided separately. +* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. +* **Lumina specific options:** Additional parameters for timestep sampling, model prediction type, discrete flow shift, and system prompt. + +
+日本語 +`lumina_train_network.py`は`train_network.py`をベースに、Lumina Image 2.0モデルに対応するための変更が加えられています。主な違いは以下の通りです。 + +* **対象モデル:** Lumina Image 2.0モデルを対象とします。 +* **モデル構造:** U-Netの代わりにNext-DiT (Transformerベース) を使用します。Text EncoderとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。 +* **引数:** Lumina Image 2.0モデル、Gemma2 Text Encoder、AEを指定する引数があります。通常、これらのコンポーネントは個別に提供されます。 +* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はLumina Image 2.0の学習では使用されません。 +* **Lumina特有の引数:** タイムステップのサンプリング、モデル予測タイプ、離散フローシフト、システムプロンプトに関する引数が追加されています。 +
+ +## 3. Preparation / 準備 + +The following files are required before starting training: + +1. **Training script:** `lumina_train_network.py` +2. **Lumina Image 2.0 model file:** `.safetensors` file for the base model. +3. **Gemma2 text encoder file:** `.safetensors` file for the text encoder. +4. **AutoEncoder (AE) file:** `.safetensors` file for the AE. +5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](link/to/dataset/config/doc).) In this document we use `my_lumina_dataset_config.toml` as an example. + +
+日本語 +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `lumina_train_network.py` +2. **Lumina Image 2.0モデルファイル:** 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイル。 +3. **Gemma2テキストエンコーダーファイル:** Gemma2テキストエンコーダーの`.safetensors`ファイル。 +4. **AutoEncoder (AE) ファイル:** AEの`.safetensors`ファイル。 +5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。 + * 例として`my_lumina_dataset_config.toml`を使用します。 +
+ +## 4. Running the Training / 学習の実行 + +Execute `lumina_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Lumina Image 2.0 specific options must be supplied. + +Example command: + +```bash +accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ + --pretrained_model_name_or_path="lumina-image-2.safetensors" \ + --gemma2="gemma-2-2b.safetensors" \ + --ae="ae.safetensors" \ + --dataset_config="my_lumina_dataset_config.toml" \ + --output_dir="./output" \ + --output_name="my_lumina_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora_lumina \ + --network_dim=8 \ + --network_alpha=8 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW" \ + --lr_scheduler="constant" \ + --timestep_sampling="nextdit_shift" \ + --discrete_flow_shift=6.0 \ + --model_prediction_type="raw" \ + --guidance_scale=4.0 \ + --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ + --use_flash_attn \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --cache_latents \ + --cache_text_encoder_outputs +``` + +*(Write the command on one line or use `\` or `^` for line breaks.)* + +
+日本語 +学習は、ターミナルから`lumina_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Lumina Image 2.0特有の引数を指定する必要があります。 + +以下に、基本的なコマンドライン実行例を示します。 + +```bash +accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ + --pretrained_model_name_or_path="lumina-image-2.safetensors" \ + --gemma2="gemma-2-2b.safetensors" \ + --ae="ae.safetensors" \ + --dataset_config="my_lumina_dataset_config.toml" \ + --output_dir="./output" \ + --output_name="my_lumina_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora_lumina \ + --network_dim=8 \ + --network_alpha=8 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW" \ + --lr_scheduler="constant" \ + --timestep_sampling="nextdit_shift" \ + --discrete_flow_shift=6.0 \ + --model_prediction_type="raw" \ + --guidance_scale=4.0 \ + --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ + --use_flash_attn \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --cache_latents \ + --cache_text_encoder_outputs +``` + +※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。 +
+ +### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 + +Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Lumina Image 2.0 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide. + +#### Model Options / モデル関連 + +* `--pretrained_model_name_or_path=""` **required** – Path to the Lumina Image 2.0 model. +* `--gemma2=""` **required** – Path to the Gemma2 text encoder `.safetensors` file. +* `--ae=""` **required** – Path to the AutoEncoder `.safetensors` file. + +#### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ + +* `--gemma2_max_token_length=` – Max token length for Gemma2. Default varies by model. +* `--timestep_sampling=` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `sigma`. **Recommended: `nextdit_shift`** +* `--discrete_flow_shift=` – Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`. +* `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `sigma_scaled`. **Recommended: `raw`** +* `--guidance_scale=` – Guidance scale for training. **Recommended: `4.0`** +* `--system_prompt=` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` +* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn`. +* `--use_sage_attn` – Use Sage Attention. +* `--sigmoid_scale=` – Scale factor for sigmoid timestep sampling. Default `1.0`. + +#### Memory and Speed / メモリ・速度関連 + +* `--blocks_to_swap=` **[experimental]** – Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`. +* `--cache_text_encoder_outputs` – Cache Gemma2 outputs to reduce memory usage. +* `--cache_latents`, `--cache_latents_to_disk` – Cache AE outputs. +* `--fp8_base` – Use FP8 precision for the base model. + +#### Network Arguments / ネットワーク引数 + +For Lumina Image 2.0, you can specify different dimensions for various components: + +* `--network_args` can include: + * `"attn_dim=4"` – Attention dimension + * `"mlp_dim=4"` – MLP dimension + * `"mod_dim=4"` – Modulation dimension + * `"refiner_dim=4"` – Refiner blocks dimension + * `"embedder_dims=[4,4,4]"` – Embedder dimensions for x, t, and caption embedders + +#### Incompatible or Deprecated Options / 非互換・非推奨の引数 + +* `--v2`, `--v_parameterization`, `--clip_skip` – Options for Stable Diffusion v1/v2 that are not used for Lumina Image 2.0. + +
+日本語 +[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。 + +#### モデル関連 + +* `--pretrained_model_name_or_path=""` **[必須]** + * 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイルのパスを指定します。 +* `--gemma2=""` **[必須]** + * Gemma2テキストエンコーダーの`.safetensors`ファイルのパスを指定します。 +* `--ae=""` **[必須]** + * AutoEncoderの`.safetensors`ファイルのパスを指定します。 + +#### Lumina Image 2.0 学習パラメータ + +* `--gemma2_max_token_length=` – Gemma2で使用するトークンの最大長を指定します。デフォルトはモデルによって異なります。 +* `--timestep_sampling=` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`sigma`です。**推奨: `nextdit_shift`** +* `--discrete_flow_shift=` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。 +* `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`sigma_scaled`です。**推奨: `raw`** +* `--guidance_scale=` – 学習時のガイダンススケールを指定します。**推奨: `4.0`** +* `--system_prompt=` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` +* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`が必要です。 +* `--use_sage_attn` – Sage Attentionを使用します。 +* `--sigmoid_scale=` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。 + +#### メモリ・速度関連 + +* `--blocks_to_swap=` **[実験的機能]** – TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。 +* `--cache_text_encoder_outputs` – Gemma2の出力をキャッシュしてメモリ使用量を削減します。 +* `--cache_latents`, `--cache_latents_to_disk` – AEの出力をキャッシュします。 +* `--fp8_base` – ベースモデルにFP8精度を使用します。 + +#### ネットワーク引数 + +Lumina Image 2.0では、各コンポーネントに対して異なる次元を指定できます: + +* `--network_args` には以下を含めることができます: + * `"attn_dim=4"` – アテンション次元 + * `"mlp_dim=4"` – MLP次元 + * `"mod_dim=4"` – モジュレーション次元 + * `"refiner_dim=4"` – リファイナーブロック次元 + * `"embedder_dims=[4,4,4]"` – x、t、キャプションエンベッダーのエンベッダー次元 + +#### 非互換・非推奨の引数 + +* `--v2`, `--v_parameterization`, `--clip_skip` – Stable Diffusion v1/v2向けの引数のため、Lumina Image 2.0学習では使用されません。 +
+ +### 4.2. Starting Training / 学習の開始 + +After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始). + +## 5. Using the Trained Model / 学習済みモデルの利用 + +When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes. + +## 6. Others / その他 + +`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`. + +### 6.1. Recommended Settings / 推奨設定 + +Based on the contributor's recommendations, here are the suggested settings for optimal training: + +**Model Files:** +* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors)) +* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors)) +* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX) + +**Key Parameters:** +* `--timestep_sampling="nextdit_shift"` +* `--discrete_flow_shift=6.0` +* `--model_prediction_type="raw"` +* `--guidance_scale=4.0` +* `--mixed_precision="bf16"` + +**System Prompts:** +* General purpose: `"You are an assistant designed to generate high-quality images based on user prompts."` +* High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` + +**Sample Prompts:** +Sample prompts can include CFG truncate (`-ct`) and Renorm CFG (`-rc`) parameters: +* `-ct 0.25 -rc 1.0` (default values) + +
+日本語 +必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 + +学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。 + +`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。 + +### 6.1. 推奨設定 + +コントリビューターの推奨に基づく、最適な学習のための推奨設定: + +**モデルファイル:** +* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precisionリンク](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) または `lumina_2_model_bf16.safetensors` ([bf16リンク](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors)) +* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([リンク](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors)) +* AutoEncoder: `ae.safetensors` ([リンク](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (FLUXと同じ) + +**主要パラメータ:** +* `--timestep_sampling="nextdit_shift"` +* `--discrete_flow_shift=6.0` +* `--model_prediction_type="raw"` +* `--guidance_scale=4.0` +* `--mixed_precision="bf16"` + +**システムプロンプト:** +* 汎用目的: `"You are an assistant designed to generate high-quality images based on user prompts."` +* 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` + +**サンプルプロンプト:** +サンプルプロンプトには CFG truncate (`-ct`) と Renorm CFG (`-rc`) パラメータを含めることができます: +* `-ct 0.25 -rc 1.0` (デフォルト値) +
\ No newline at end of file From 8a72f56c9f65d24646b3db8a902a74b077e07106 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Jul 2025 22:14:16 +0900 Subject: [PATCH 59/73] fix: clarify Flash Attention usage in lumina training guide --- docs/lumina_train_network.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md index 1c3794ab..2872f513 100644 --- a/docs/lumina_train_network.md +++ b/docs/lumina_train_network.md @@ -18,7 +18,6 @@ This guide assumes you already understand the basics of LoRA training. For commo
日本語 -ステータス:内容を一通り確認した `lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。 @@ -100,7 +99,6 @@ accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ --model_prediction_type="raw" \ --guidance_scale=4.0 \ --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ - --use_flash_attn \ --max_train_epochs=10 \ --save_every_n_epochs=1 \ --mixed_precision="bf16" \ @@ -137,7 +135,6 @@ accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ --model_prediction_type="raw" \ --guidance_scale=4.0 \ --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ - --use_flash_attn \ --max_train_epochs=10 \ --save_every_n_epochs=1 \ --mixed_precision="bf16" \ @@ -167,8 +164,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md * `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `sigma_scaled`. **Recommended: `raw`** * `--guidance_scale=` – Guidance scale for training. **Recommended: `4.0`** * `--system_prompt=` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` -* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn`. -* `--use_sage_attn` – Use Sage Attention. +* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training. * `--sigmoid_scale=` – Scale factor for sigmoid timestep sampling. Default `1.0`. #### Memory and Speed / メモリ・速度関連 @@ -214,8 +210,7 @@ For Lumina Image 2.0, you can specify different dimensions for various component * `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`sigma_scaled`です。**推奨: `raw`** * `--guidance_scale=` – 学習時のガイダンススケールを指定します。**推奨: `4.0`** * `--system_prompt=` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` -* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`が必要です。 -* `--use_sage_attn` – Sage Attentionを使用します。 +* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。 * `--sigmoid_scale=` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。 #### メモリ・速度関連 From 1a9bf2ab56ef488e7cf1789cf7689977fdeece5d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 13 Jul 2025 20:45:09 +0900 Subject: [PATCH 60/73] feat: add interactive mode for generating multiple images --- lumina_minimal_inference.py | 125 ++++++++++++++++++++++++++++++------ 1 file changed, 106 insertions(+), 19 deletions(-) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index 4f915179..31362c00 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -257,6 +257,11 @@ def setup_parser() -> argparse.ArgumentParser: help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument( + "--interactive", + action="store_true", + help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する", + ) return parser @@ -294,9 +299,7 @@ if __name__ == "__main__": multiplier = 1.0 weights_sd = load_file(weights_file) - lora_model, _ = lora_lumina.create_network_from_weights( - multiplier, None, ae, [gemma2], model, weights_sd, True - ) + lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True) if args.merge_lora_weights: lora_model.merge_to([gemma2], model, weights_sd) @@ -304,25 +307,109 @@ if __name__ == "__main__": lora_model.apply_to([gemma2], model) info = lora_model.load_state_dict(weights_sd, strict=True) logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.to(device) + lora_model.set_multiplier(multiplier) lora_model.eval() lora_models.append(lora_model) - generate_image( - model, - gemma2, - ae, - args.prompt, - args.system_prompt, - args.seed, - args.image_width, - args.image_height, - args.steps, - args.guidance_scale, - args.negative_prompt, - args, - args.cfg_trunc_ratio, - args.renorm_cfg, - ) + if not args.interactive: + generate_image( + model, + gemma2, + ae, + args.prompt, + args.system_prompt, + args.seed, + args.image_width, + args.image_height, + args.steps, + args.guidance_scale, + args.negative_prompt, + args, + args.cfg_trunc_ratio, + args.renorm_cfg, + ) + else: + # Interactive mode loop + image_width = args.image_width + image_height = args.image_height + steps = args.steps + guidance_scale = args.guidance_scale + cfg_trunc_ratio = args.cfg_trunc_ratio + renorm_cfg = args.renorm_cfg + + print("Entering interactive mode.") + while True: + print( + "\nEnter prompt (or 'exit'). Options: --w --h --s --d --g --n --ctr --rcfg --m " + ) + user_input = input() + if user_input.lower() == "exit": + break + if not user_input: + continue + + # Parse options + options = user_input.split("--") + prompt = options[0].strip() + + # Set defaults for each generation + seed = None # New random seed each time unless specified + negative_prompt = args.negative_prompt # Reset to default + + for opt in options[1:]: + try: + opt = opt.strip() + if not opt: + continue + + key, value = (opt.split(None, 1) + [""])[:2] + + if key == "w": + image_width = int(value) + elif key == "h": + image_height = int(value) + elif key == "s": + steps = int(value) + elif key == "d": + seed = int(value) + elif key == "g": + guidance_scale = float(value) + elif key == "n": + negative_prompt = value if value != "-" else "" + elif key == "ctr": + cfg_trunc_ratio = float(value) + elif key == "rcfg": + renorm_cfg = float(value) + elif key == "m": + multipliers = value.split(",") + if len(multipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(multipliers[i].strip())) + else: + logger.warning(f"Unknown option: --{key}") + + except (ValueError, IndexError) as e: + logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}") + + generate_image( + model, + gemma2, + ae, + prompt, + args.system_prompt, + seed, + image_width, + image_height, + steps, + guidance_scale, + negative_prompt, + args, + cfg_trunc_ratio, + renorm_cfg, + ) logger.info("Done.") From 88dc3213a90fffce3586e2f87fa74cb106488f5a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 13 Jul 2025 20:46:24 +0900 Subject: [PATCH 61/73] fix: support LoRA w/o TE for create_network_from_weights --- networks/lora_lumina.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index e4149b4a..0929e839 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -562,23 +562,26 @@ class LoRANetwork(torch.nn.Module): # Set dim/alpha to modules dim/alpha if modules_dim is not None and modules_alpha is not None: - # モジュール指定あり + # network from weights if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] + else: + dim = 0 # skip if not found - # Set dims to type_dims - if is_lumina and type_dims is not None: - identifier = [ - ("attention",), # attention layers - ("mlp",), # MLP layers - ("modulation",), # modulation layers - ("refiner",), # refiner blocks - ] - for i, d in enumerate(type_dims): - if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d # may be 0 for skip - break + else: + # Set dims to type_dims + if is_lumina and type_dims is not None: + identifier = [ + ("attention",), # attention layers + ("mlp",), # MLP layers + ("modulation",), # modulation layers + ("refiner",), # refiner blocks + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break # Drop blocks if we are only training some blocks if ( From 88960e63094bcb96fae318c526867fe409fade18 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 13 Jul 2025 20:49:38 +0900 Subject: [PATCH 62/73] doc: update lumina LoRA training guide --- docs/lumina_train_network.md | 42 ++++++++++++++++-------------------- library/lumina_train_util.py | 4 ++-- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md index 2872f513..e811f68b 100644 --- a/docs/lumina_train_network.md +++ b/docs/lumina_train_network.md @@ -8,12 +8,12 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina `lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts a Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE). -This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md). +This guide assumes you already understand the basics of LoRA training. For common usage and options, see the train_network.py guide (to be documented). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md). **Prerequisites:** * The `sd-scripts` repository has been cloned and the Python environment is ready. -* A training dataset has been prepared. See the [Dataset Configuration Guide](link/to/dataset/config/doc). +* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md). * Lumina Image 2.0 model files for training are available.
@@ -21,12 +21,12 @@ This guide assumes you already understand the basics of LoRA training. For commo `lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。 -このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。 +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、`train_network.py`のガイド(作成中)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。 **前提条件:** * `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 -* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください) +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください) * 学習対象のLumina Image 2.0モデルファイルが準備できていること。
@@ -59,7 +59,14 @@ The following files are required before starting training: 2. **Lumina Image 2.0 model file:** `.safetensors` file for the base model. 3. **Gemma2 text encoder file:** `.safetensors` file for the text encoder. 4. **AutoEncoder (AE) file:** `.safetensors` file for the AE. -5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](link/to/dataset/config/doc).) In this document we use `my_lumina_dataset_config.toml` as an example. +5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md). In this document we use `my_lumina_dataset_config.toml` as an example. + + +**Model Files:** +* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors)) +* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors)) +* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX) +
日本語 @@ -69,8 +76,11 @@ The following files are required before starting training: 2. **Lumina Image 2.0モデルファイル:** 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイル。 3. **Gemma2テキストエンコーダーファイル:** Gemma2テキストエンコーダーの`.safetensors`ファイル。 4. **AutoEncoder (AE) ファイル:** AEの`.safetensors`ファイル。 -5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。 +5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。 * 例として`my_lumina_dataset_config.toml`を使用します。 + +**モデルファイル** は英語ドキュメントの通りです。 +
## 4. Running the Training / 学習の実行 @@ -97,7 +107,6 @@ accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ --timestep_sampling="nextdit_shift" \ --discrete_flow_shift=6.0 \ --model_prediction_type="raw" \ - --guidance_scale=4.0 \ --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ --max_train_epochs=10 \ --save_every_n_epochs=1 \ @@ -133,7 +142,6 @@ accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ --timestep_sampling="nextdit_shift" \ --discrete_flow_shift=6.0 \ --model_prediction_type="raw" \ - --guidance_scale=4.0 \ --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ --max_train_epochs=10 \ --save_every_n_epochs=1 \ @@ -158,11 +166,10 @@ Besides the arguments explained in the [train_network.py guide](train_network.md #### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ -* `--gemma2_max_token_length=` – Max token length for Gemma2. Default varies by model. +* `--gemma2_max_token_length=` – Max token length for Gemma2. Default is 256. * `--timestep_sampling=` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `sigma`. **Recommended: `nextdit_shift`** * `--discrete_flow_shift=` – Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`. * `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `sigma_scaled`. **Recommended: `raw`** -* `--guidance_scale=` – Guidance scale for training. **Recommended: `4.0`** * `--system_prompt=` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` * `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training. * `--sigmoid_scale=` – Scale factor for sigmoid timestep sampling. Default `1.0`. @@ -204,11 +211,10 @@ For Lumina Image 2.0, you can specify different dimensions for various component #### Lumina Image 2.0 学習パラメータ -* `--gemma2_max_token_length=` – Gemma2で使用するトークンの最大長を指定します。デフォルトはモデルによって異なります。 +* `--gemma2_max_token_length=` – Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。 * `--timestep_sampling=` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`sigma`です。**推奨: `nextdit_shift`** * `--discrete_flow_shift=` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。 * `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`sigma_scaled`です。**推奨: `raw`** -* `--guidance_scale=` – 学習時のガイダンススケールを指定します。**推奨: `4.0`** * `--system_prompt=` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` * `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。 * `--sigmoid_scale=` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。 @@ -252,16 +258,10 @@ When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is Based on the contributor's recommendations, here are the suggested settings for optimal training: -**Model Files:** -* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors)) -* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors)) -* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX) - **Key Parameters:** * `--timestep_sampling="nextdit_shift"` * `--discrete_flow_shift=6.0` * `--model_prediction_type="raw"` -* `--guidance_scale=4.0` * `--mixed_precision="bf16"` **System Prompts:** @@ -284,16 +284,10 @@ Sample prompts can include CFG truncate (`-ct`) and Renorm CFG (`-rc`) parameter コントリビューターの推奨に基づく、最適な学習のための推奨設定: -**モデルファイル:** -* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precisionリンク](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) または `lumina_2_model_bf16.safetensors` ([bf16リンク](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors)) -* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([リンク](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors)) -* AutoEncoder: `ae.safetensors` ([リンク](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (FLUXと同じ) - **主要パラメータ:** * `--timestep_sampling="nextdit_shift"` * `--discrete_flow_shift=6.0` * `--model_prediction_type="raw"` -* `--guidance_scale=4.0` * `--mixed_precision="bf16"` **システムプロンプト:** diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 45f22bc4..1cf9278a 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1042,8 +1042,8 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): "--gemma2_max_token_length", type=int, default=None, - help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev" - " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + help="maximum token length for Gemma2. if omitted, 256" + " / Gemma2の最大トークン長。省略された場合、256になります", ) parser.add_argument( From 999df5ec15c900a7dde3ac57c46db048ad988417 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 13 Jul 2025 20:52:00 +0900 Subject: [PATCH 63/73] fix: update default values for timestep_sampling and model_prediction_type in training arguments --- docs/lumina_train_network.md | 8 ++++---- library/lumina_train_util.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md index e811f68b..45695e89 100644 --- a/docs/lumina_train_network.md +++ b/docs/lumina_train_network.md @@ -167,9 +167,9 @@ Besides the arguments explained in the [train_network.py guide](train_network.md #### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ * `--gemma2_max_token_length=` – Max token length for Gemma2. Default is 256. -* `--timestep_sampling=` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `sigma`. **Recommended: `nextdit_shift`** +* `--timestep_sampling=` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `shift`. **Recommended: `nextdit_shift`** * `--discrete_flow_shift=` – Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`. -* `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `sigma_scaled`. **Recommended: `raw`** +* `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `raw`. **Recommended: `raw`** * `--system_prompt=` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` * `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training. * `--sigmoid_scale=` – Scale factor for sigmoid timestep sampling. Default `1.0`. @@ -212,9 +212,9 @@ For Lumina Image 2.0, you can specify different dimensions for various component #### Lumina Image 2.0 学習パラメータ * `--gemma2_max_token_length=` – Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。 -* `--timestep_sampling=` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`sigma`です。**推奨: `nextdit_shift`** +* `--timestep_sampling=` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`shift`です。**推奨: `nextdit_shift`** * `--discrete_flow_shift=` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。 -* `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`sigma_scaled`です。**推奨: `raw`** +* `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`** * `--system_prompt=` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` * `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。 * `--sigmoid_scale=` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。 diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 1cf9278a..0645a8ae 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1049,9 +1049,9 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], - default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。", + default="shift", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", ) parser.add_argument( "--sigmoid_scale", @@ -1062,7 +1062,7 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--model_prediction_type", choices=["raw", "additive", "sigma_scaled"], - default="sigma_scaled", + default="raw", help="How to interpret and process the model prediction: " "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." " / モデル予測の解釈と処理方法:" From 30295c96686c90d4773e12fd5eb248e0a6bd406b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 13 Jul 2025 21:00:27 +0900 Subject: [PATCH 64/73] fix: update parameter names for CFG truncate and Renorm CFG in documentation and code --- docs/lumina_train_network.md | 10 ++++++---- library/train_util.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md index 45695e89..cb3b600f 100644 --- a/docs/lumina_train_network.md +++ b/docs/lumina_train_network.md @@ -269,11 +269,12 @@ Based on the contributor's recommendations, here are the suggested settings for * High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` **Sample Prompts:** -Sample prompts can include CFG truncate (`-ct`) and Renorm CFG (`-rc`) parameters: -* `-ct 0.25 -rc 1.0` (default values) +Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`-rcfg`) parameters: +* `--ctr 0.25 --rcfg 1.0` (default values)
日本語 + 必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。 @@ -295,6 +296,7 @@ Sample prompts can include CFG truncate (`-ct`) and Renorm CFG (`-rc`) parameter * 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` **サンプルプロンプト:** -サンプルプロンプトには CFG truncate (`-ct`) と Renorm CFG (`-rc`) パラメータを含めることができます: -* `-ct 0.25 -rc 1.0` (デフォルト値) +サンプルプロンプトには CFG truncate (`--ctr`) と Renorm CFG (`--rcfg`) パラメータを含めることができます: +* `--ctr 0.25 --rcfg 1.0` (デフォルト値) +
\ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index 1d80bcd8..2e8e9c29 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6208,12 +6208,12 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["controlnet_image"] = m.group(1) continue - m = re.match(r"ct (.+)", parg, re.IGNORECASE) + m = re.match(r"ctr (.+)", parg, re.IGNORECASE) if m: prompt_dict["cfg_trunc_ratio"] = float(m.group(1)) continue - m = re.match(r"rc (.+)", parg, re.IGNORECASE) + m = re.match(r"rcfg (.+)", parg, re.IGNORECASE) if m: prompt_dict["renorm_cfg"] = float(m.group(1)) continue From 13ccfc39f860b9653b2f22ec1619a01ab8ffab90 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 13 Jul 2025 21:26:06 +0900 Subject: [PATCH 65/73] fix: update flow matching loss and variable names --- lumina_train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lumina_train.py b/lumina_train.py index 0a91f4a0..a333427d 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -294,7 +294,7 @@ def train(args): # load lumina nextdit = lumina_util.load_lumina_model( args.pretrained_model_name_or_path, - loading_dtype, + weight_dtype, torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, use_flash_attn=args.use_flash_attn, @@ -494,6 +494,8 @@ def train(args): clean_memory_on_device(accelerator.device) + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 @@ -739,7 +741,7 @@ def train(args): with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = nextdit( - x=img, # image latents (B, C, H, W) + x=noisy_model_input, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features cap_mask=gemma2_attn_mask.to( @@ -751,8 +753,8 @@ def train(args): args, model_pred, noisy_model_input, sigmas ) - # flow matching loss: this is different from SD3 - target = noise - latents + # flow matching loss + target = latents - noise # calculate loss huber_c = train_util.get_huber_threshold_if_needed( From e0fcb5152a8c6f36d27b0f9f0e20e4ce75860c12 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Jul 2025 21:34:35 +0900 Subject: [PATCH 66/73] feat: support Neta Lumina all-in-one weights --- library/lumina_util.py | 40 ++++++++++++++++++++++++++++++------- lumina_minimal_inference.py | 4 ++-- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/library/lumina_util.py b/library/lumina_util.py index 452b242f..87853ef6 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -44,10 +44,21 @@ def load_lumina_model( """ logger.info("Building Lumina") with torch.device("meta"): - model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype) + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to( + dtype + ) logger.info(f"Loading state dict from {ckpt_path}") state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + + # Neta-Lumina support + if "model.diffusion_model.cap_embedder.0.weight" in state_dict: + # remove "model.diffusion_model." prefix + filtered_state_dict = { + k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.") + } + state_dict = filtered_state_dict + info = model.load_state_dict(state_dict, strict=False, assign=True) logger.info(f"Loaded Lumina: {info}") return model @@ -78,6 +89,13 @@ def load_ae( logger.info(f"Loading state dict from {ckpt_path}") sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + + # Neta-Lumina support + if "vae.decoder.conv_in.bias" in sd: + # remove "vae." prefix + filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")} + sd = filtered_sd + info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae @@ -152,6 +170,16 @@ def load_gemma2( break # the model doesn't have annoying prefix sd[new_key] = sd.pop(key) + # Neta-Lumina support + if "text_encoders.gemma2_2b.logit_scale" in sd: + # remove "text_encoders.gemma2_2b.transformer.model." prefix + filtered_sd = { + k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v + for k, v in sd.items() + if k.startswith("text_encoders.gemma2_2b.transformer.model.") + } + sd = filtered_sd + info = gemma2.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Gemma2: {info}") return gemma2 @@ -173,7 +201,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: return x - DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = { # Embedding layers "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight", @@ -211,11 +238,11 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): # Handle block-specific patterns - if '().' in diff_key: + if "()." in diff_key: for block_idx in range(num_double_blocks): - block_alpha_key = alpha_key.replace('().', f'{block_idx}.') - block_diff_key = diff_key.replace('().', f'{block_idx}.') - + block_alpha_key = alpha_key.replace("().", f"{block_idx}.") + block_diff_key = diff_key.replace("().", f"{block_idx}.") + # Search for and convert block-specific keys for input_key, value in list(sd.items()): if input_key == block_diff_key: @@ -228,6 +255,5 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict else: print(f"Not found: {diff_key}") - logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") return new_sd diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index 31362c00..d829616b 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -231,13 +231,13 @@ def setup_parser() -> argparse.ArgumentParser: "--cfg_trunc_ratio", type=float, default=0.25, - help="TBD", + help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the last 25% of timesteps will be guided.", ) parser.add_argument( "--renorm_cfg", type=float, default=1.0, - help="TBD", + help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.", ) parser.add_argument( "--use_flash_attn", From 25771a5180a134190c0e9b540ee5a074ff70e6cd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Jul 2025 21:53:13 +0900 Subject: [PATCH 67/73] fix: update help text for cfg_trunc_ratio argument --- lumina_minimal_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index d829616b..691ee418 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -231,7 +231,7 @@ def setup_parser() -> argparse.ArgumentParser: "--cfg_trunc_ratio", type=float, default=0.25, - help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the last 25% of timesteps will be guided.", + help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25% of timesteps will be guided.", ) parser.add_argument( "--renorm_cfg", From c0c36a4e2ffb9a8438f490ff3d0deca8a03bbd26 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Jul 2025 21:58:03 +0900 Subject: [PATCH 68/73] fix: remove duplicated latent normalization in decoding --- lumina_minimal_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index 691ee418..87dc9a19 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -158,7 +158,7 @@ def generate_image( # 5. Decode latents # logger.info("Decoding image...") - latents = latents / ae.scale_factor + ae.shift_factor + # latents = latents / ae.scale_factor + ae.shift_factor with torch.no_grad(): image = ae.decode(latents.to(ae_dtype)) image = (image / 2 + 0.5).clamp(0, 1) From a7b33f320495afa39e353e0c583accf15ad9cb20 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 15 Jul 2025 22:36:46 -0400 Subject: [PATCH 69/73] Fix alphas cumprod after add_noise for DDIMScheduler --- library/train_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 36d419fd..285870fa 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6008,6 +6008,8 @@ def get_noise_noisy_latents_and_timesteps( else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() + return noise, noisy_latents, timesteps From 3adbbb6e3347b9a0da852a65a85d58a5da777443 Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Wed, 16 Jul 2025 16:09:20 -0400 Subject: [PATCH 70/73] Add note about why we are moving it --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 285870fa..165d873b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6008,6 +6008,7 @@ def get_noise_noisy_latents_and_timesteps( else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # This moves the alphas_cumprod back to the CPU after it is moved in noise_scheduler.add_noise noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() return noise, noisy_latents, timesteps From aec7e160949d900f709fe3c10a8602362dc097f2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 21 Jul 2025 13:14:59 +0900 Subject: [PATCH 71/73] feat: add an option to add system prompt for negative in lumina inference --- lumina_minimal_inference.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index 87dc9a19..47d6d30b 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -48,7 +48,7 @@ def generate_image( steps: int, guidance_scale: float, negative_prompt: Optional[str], - args, + args: argparse.Namespace, cfg_trunc_ratio: float = 0.25, renorm_cfg: float = 1.0, ): @@ -88,7 +88,9 @@ def generate_image( with torch.no_grad(): gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks) - tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True) + tokens_and_masks = tokenize_strategy.tokenize( + negative_prompt, is_negative=True and not args.add_system_prompt_to_negative_prompt + ) with torch.no_grad(): neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks) @@ -215,6 +217,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')") parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM") parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model") + parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt") parser.add_argument( "--gemma2_max_token_length", type=int, @@ -231,7 +234,7 @@ def setup_parser() -> argparse.ArgumentParser: "--cfg_trunc_ratio", type=float, default=0.25, - help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25% of timesteps will be guided.", + help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.", ) parser.add_argument( "--renorm_cfg", From d300f19045e8c87bd5dd2dcd9f3cf84571f80206 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 21 Jul 2025 13:15:09 +0900 Subject: [PATCH 72/73] docs: update Lumina training guide to include inference script and options --- docs/lumina_train_network.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md index cb3b600f..5f2fda17 100644 --- a/docs/lumina_train_network.md +++ b/docs/lumina_train_network.md @@ -1,5 +1,3 @@ -Status: reviewed - # LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository. @@ -198,6 +196,7 @@ For Lumina Image 2.0, you can specify different dimensions for various component
日本語 + [`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。 #### モデル関連 @@ -250,6 +249,18 @@ After setting the required arguments, run the command to begin training. The ove When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes. +### Inference with scripts in this repository / このリポジトリのスクリプトを使用した推論 + +The inference script is also available. The script is `lumina_minimal_inference.py`. See `--help` for options. + +``` +python lumina_minimal_inference.py --pretrained_model_name_or_path path/to/lumina.safetensors --gemma2_path path/to/gemma.safetensors" --ae_path path/to/flux_ae.safetensors --output_dir path/to/output_dir --offload --seed 1234 --prompt "Positive prompt" --system_prompt "You are an assistant designed to generate high-quality images based on user prompts." --negative_prompt "negative prompt" +``` + +`--add_system_prompt_to_negative_prompt` option can be used to add the system prompt to the negative prompt. + +`--lora_weights` option can be used to specify the LoRA weights file, and optional multiplier (like `path;1.0`). + ## 6. Others / その他 `lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`. @@ -279,6 +290,8 @@ Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`-rcfg`) param 学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。 +当リポジトリ内の推論スクリプトを用いて推論することも可能です。スクリプトは`lumina_minimal_inference.py`です。オプションは`--help`で確認できます。記述例は英語版のドキュメントをご確認ください。 + `lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。 ### 6.1. 推奨設定 From 518545bffbd8b2629944b9d3c65e6e77f167e7ce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 21 Jul 2025 13:16:42 +0900 Subject: [PATCH 73/73] docs: add support information for Lumina-Image 2.0 in recent updates --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 149f453b..b6365644 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,10 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed ### Recent Updates +Jul 21, 2025: +- Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions. + - Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details. + Jul 10, 2025: - [AI Coding Agents](#for-developers-using-ai-coding-agents) section is added to the README. This section provides instructions for developers using AI coding agents like Claude and Gemini to understand the project context and coding standards.