From 025cca699ba0ee05b91d37e5b7779ec28d076620 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Feb 2025 01:29:18 -0500 Subject: [PATCH] Fix samples, LoRA training. Add system prompt, use_flash_attn --- library/config_util.py | 6 + library/lumina_models.py | 200 +++++++++++++----------- library/lumina_train_util.py | 289 ++++++++++++++++++++++++++--------- library/lumina_util.py | 86 +++++------ library/sd3_train_utils.py | 259 ++++++++++++++++++++++++++----- library/strategy_base.py | 84 ++++++++-- library/strategy_lumina.py | 153 +++++++++++++++---- library/train_util.py | 21 ++- lumina_train_network.py | 171 +++++++++------------ train_network.py | 5 +- 10 files changed, 888 insertions(+), 386 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a2e07dc6..ca14dfb1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,6 +75,7 @@ class BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 + system_prompt: Optional[str] = None @dataclass @@ -106,6 +107,7 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 + system_prompt: Optional[str] = None @dataclass @@ -196,6 +198,7 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, + "system_prompt": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -241,6 +244,7 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "system_prompt": str, } # options handled by argparse but not handled by user config @@ -526,6 +530,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} + system_prompt: {dataset.system_prompt} """) if 
dataset.enable_bucket: @@ -559,6 +564,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu token_warmup_step: {subset.token_warmup_step}, alpha_mask: {subset.alpha_mask} custom_attributes: {subset.custom_attributes} + system_prompt: {subset.system_prompt} """), " ") if is_dreambooth: diff --git a/library/lumina_models.py b/library/lumina_models.py index 36c3b979..f819b68f 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -14,14 +14,19 @@ from typing import List, Optional, Tuple from dataclasses import dataclass from einops import rearrange -from flash_attn import flash_attn_varlen_func -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F +try: + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + pass + try: from apex.normalization import FusedRMSNorm as RMSNorm except: @@ -75,7 +80,15 @@ class LuminaParams: @classmethod def get_7b_config(cls) -> "LuminaParams": """Returns the configuration for the 7B parameter model""" - return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512]) + return cls( + patch_size=2, + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[64, 64, 64], + axes_lens=[300, 512, 512], + ) class GradientCheckpointMixin(nn.Module): @@ -248,6 +261,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + use_flash_attn=False, ): """ Initialize the Attention module. 
@@ -286,7 +300,7 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - self.flash_attn = False + self.use_flash_attn = use_flash_attn # self.attention_processor = xformers.ops.memory_efficient_attention self.attention_processor = F.scaled_dot_product_attention @@ -294,35 +308,63 @@ class JointAttention(nn.Module): def set_attention_processor(self, attention_processor): self.attention_processor = attention_processor - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: + def forward( + self, + x: Tensor, + x_mask: Tensor, + freqs_cis: Tensor, + ) -> Tensor: """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. 
+ x: + x_mask: + freqs_cis: """ - with torch.autocast("cuda", enabled=False): - x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x * freqs_cis).flatten(3) - return x_out.type_as(x_in) + bsz, seqlen, _ = x.shape + dtype = x.dtype + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = apply_rope(xq, freqs_cis=freqs_cis) + xk = apply_rope(xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + softmax_scale = math.sqrt(1 / self.head_dim) + + if self.use_flash_attn: + output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) + else: + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = ( + self.attention_processor( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), + scale=softmax_scale, + ) + .permute(0, 2, 1, 3) + .to(dtype) + ) + + output = output.flatten(-2) + return self.out(output) # copied from huggingface modeling_llama.py def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): @@ -377,46 +419,17 @@ class JointAttention(nn.Module): (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) - def forward( + def flash_attn( self, - x: Tensor, + q: Tensor, + k: Tensor, + v: Tensor, x_mask: Tensor, - freqs_cis: Tensor, + softmax_scale, ) -> Tensor: - """ + bsz, seqlen, _, _ = q.shape - Args: - 
x: - x_mask: - freqs_cis: - - Returns: - - """ - bsz, seqlen, _ = x.shape - dtype = x.dtype - - xq, xk, xv = torch.split( - self.qkv(x), - [ - self.n_local_heads * self.head_dim, - self.n_local_kv_heads * self.head_dim, - self.n_local_kv_heads * self.head_dim, - ], - dim=-1, - ) - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xq = self.q_norm(xq) - xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) - xq, xk = xq.to(dtype), xk.to(dtype) - - softmax_scale = math.sqrt(1 / self.head_dim) - - if self.flash_attn: + try: # begin var_len flash attn ( query_states, @@ -425,7 +438,7 @@ class JointAttention(nn.Module): indices_q, cu_seq_lens, max_seq_lens, - ) = self._upad_input(xq, xk, xv, x_mask, seqlen) + ) = self._upad_input(q, k, v, x_mask, seqlen) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -445,27 +458,12 @@ class JointAttention(nn.Module): output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) # end var_len_flash_attn - else: - n_rep = self.n_local_heads // self.n_local_kv_heads - if n_rep >= 1: - xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - - output = ( - self.attention_processor( - xq.permute(0, 2, 1, 3), - xk.permute(0, 2, 1, 3), - xv.permute(0, 2, 1, 3), - attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), - scale=softmax_scale, - ) - .permute(0, 2, 1, 3) - .to(dtype) + return output + except NameError as e: + raise RuntimeError( + f"Could not load flash attention. Please install flash_attn. 
/ フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}" ) - output = output.flatten(-2) - return self.out(output) - def apply_rope( x_in: torch.Tensor, @@ -563,6 +561,7 @@ class JointTransformerBlock(GradientCheckpointMixin): norm_eps: float, qk_norm: bool, modulation=True, + use_flash_attn=False, ) -> None: """ Initialize a TransformerBlock. @@ -585,7 +584,7 @@ class JointTransformerBlock(GradientCheckpointMixin): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn) self.feed_forward = FeedForward( dim=dim, hidden_dim=4 * dim, @@ -711,7 +710,12 @@ class FinalLayer(GradientCheckpointMixin): class RopeEmbedder: - def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]): + def __init__( + self, + theta: float = 10000.0, + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], + ): super().__init__() self.theta = theta self.axes_dims = axes_dims @@ -750,6 +754,7 @@ class NextDiT(nn.Module): cap_feat_dim: int = 5120, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512], + use_flash_attn=False, ) -> None: """ Initialize the NextDiT model. 
@@ -803,6 +808,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=False, + use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -828,6 +834,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + use_flash_attn=use_flash_attn, ) for layer_id in range(n_refiner_layers) ] @@ -848,6 +855,7 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + use_flash_attn=use_flash_attn, ) for layer_id in range(n_layers) ] @@ -988,8 +996,20 @@ class NextDiT(nn.Module): freqs_cis = self.rope_embedder(position_ids) # Create separate rotary embeddings for captions and images - cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) - img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) + cap_freqs_cis = torch.zeros( + bsz, + encoder_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + img_freqs_cis = torch.zeros( + bsz, + image_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 9dac9c9f..414b2849 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -1,21 +1,28 @@ +import inspect +import enum import argparse import math import os import numpy as np import time -from typing import Callable, Dict, List, Optional, Tuple, Any +from typing import Callable, Dict, List, Optional, Tuple, Any, Union import torch from torch import Tensor +from torchdiffeq import odeint from accelerate import Accelerator, PartialState from transformers import Gemma2Model from tqdm import tqdm from PIL import Image from safetensors.torch import save_file +from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler from library import 
lumina_models, lumina_util, strategy_base, strategy_lumina, train_util +from library.flux_models import AutoEncoder from library.device_utils import init_ipex, clean_memory_on_device from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.lumina_dpm_solver import NoiseScheduleFlow, DPM_Solver +import library.lumina_path as path init_ipex() @@ -162,12 +169,12 @@ def sample_image_inference( args: argparse.Namespace, nextdit: lumina_models.NextDiT, gemma2_model: Gemma2Model, - vae: torch.nn.Module, + vae: AutoEncoder, save_dir: str, prompt_dict: Dict[str, str], epoch: int, global_step: int, - sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]], + sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]], prompt_replacement: Optional[Tuple[str, str]] = None, controlnet=None, ): @@ -179,12 +186,12 @@ def sample_image_inference( args (argparse.Namespace): Arguments object nextdit (lumina_models.NextDiT): NextDiT model gemma2_model (Gemma2Model): Gemma2 model - vae (torch.nn.Module): VAE model + vae (AutoEncoder): VAE model save_dir (str): Directory to save images prompt_dict (Dict[str, str]): Prompt dictionary epoch (int): Epoch number steps (int): Number of steps to run - sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs + sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing Gemma 2 outputs prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None. 
Returns: @@ -192,16 +199,19 @@ def sample_image_inference( """ assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 38) - width = prompt_dict.get("width", 1024) - height = prompt_dict.get("height", 1024) - guidance_scale: int = prompt_dict.get("scale", 3.5) - seed: int = prompt_dict.get("seed", None) + sample_steps = int(prompt_dict.get("sample_steps", 38)) + width = int(prompt_dict.get("width", 1024)) + height = int(prompt_dict.get("height", 1024)) + guidance_scale = float(prompt_dict.get("scale", 3.5)) + seed = prompt_dict.get("seed", None) controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") negative_prompt: str = prompt_dict.get("negative_prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + seed = int(seed) if seed is not None else None + assert seed is None or seed > 0, f"Invalid seed {seed}" + if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: @@ -213,10 +223,10 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - height = max(64, height - height % 16) # round to divisible by 16 - width = max(64, width - width % 16) # round to divisible by 16 + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 logger.info(f"prompt: {prompt}") - # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"negative_prompt: {negative_prompt}") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") @@ -232,46 +242,51 @@ def sample_image_inference( assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - gemma2_conds = [] + system_prompt = 
args.system_prompt or "" + + # Apply system prompt to prompts + prompt = system_prompt + prompt + negative_prompt = system_prompt + negative_prompt + + # Get sample prompts from cache if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: gemma2_conds = sample_prompts_gemma2_outputs[prompt] logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") + + if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: + neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] + logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + + # Load sample prompts from Gemma 2 if gemma2_model is not None: logger.info(f"Encoding prompt with Gemma2: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) - encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) - # if gemma2_conds is not cached, use encoded_gemma2_conds - if len(gemma2_conds) == 0: - gemma2_conds = encoded_gemma2_conds - else: - # if encoded_gemma2_conds is not None, update cached gemma2_conds - for i in range(len(encoded_gemma2_conds)): - if encoded_gemma2_conds[i] is not None: - gemma2_conds[i] = encoded_gemma2_conds[i] + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) # Unpack Gemma2 outputs gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds + neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds # sample image weight_dtype = vae.dtype # TOFO give dtype as argument latent_height = height // 8 latent_width = width // 8 + latent_channels = 16 noise = torch.randn( 1, - 16, + latent_channels, latent_height, latent_width, device=accelerator.device, dtype=weight_dtype, generator=generator, ) - # Prompts are 
paired positive/negative - noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1) - timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) - # img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) - gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) + scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0, use_karras_sigmas=True) + timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) # if controlnet_image is not None: # controlnet_image = Image.open(controlnet_image).convert("RGB") @@ -280,16 +295,25 @@ def sample_image_inference( # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) with accelerator.autocast(): - x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale) + x = denoise( + scheduler, + nextdit, + noise, + gemma2_hidden_states, + gemma2_attn_mask.to(accelerator.device), + neg_gemma2_hidden_states, + neg_gemma2_attn_mask.to(accelerator.device), + timesteps=timesteps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ) - # x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) - - # latent to image + # Latent to image clean_memory_on_device(accelerator.device) org_vae_device = vae.device # will be on cpu vae.to(accelerator.device) # distributed_state.device is same as accelerator.device with accelerator.autocast(): - x = vae.decode(x) + x = vae.decode((x / vae.scale_factor) + vae.shift_factor) vae.to(org_vae_device) clean_memory_on_device(accelerator.device) @@ -317,30 +341,25 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def time_shift(mu: float, sigma: float, t: Tensor): - """ - Get time shift - - Args: - mu (float): mu value. - sigma (float): sigma value. 
- t (Tensor): timestep. - - Return: - float: time shift - """ - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) +def time_shift(mu: float, sigma: float, t: torch.Tensor): + # the following implementation was original for t=0: clean / t=1: noise + # Since we adopt the reverse, the 1-t operations are needed + t = 1 - t + t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + t = 1 - t + return t -def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: +def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]: """ Get linear function Args: - x1 (float, optional): x1 value. Defaults to 256. - y1 (float, optional): y1 value. Defaults to 0.5. - x2 (float, optional): x2 value. Defaults to 4096. - y2 (float, optional): y2 value. Defaults to 1.15. + image_seq_len, + x1 base_seq_len: int = 256, + y2 max_seq_len: int = 4096, + y1 base_shift: float = 0.5, + y2 max_shift: float = 1.15, Return: Callable[[float], float]: linear function @@ -370,51 +389,164 @@ def get_schedule( Return: List[float]: timesteps schedule """ - # extra step for zero - timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = torch.linspace(1, 1 / num_steps, num_steps) # shifting the schedule to favor high timesteps for higher signal images if shift: # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +) -> 
Tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + def denoise( - model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0 + scheduler, + model: lumina_models.NextDiT, + img: Tensor, + txt: Tensor, + txt_mask: Tensor, + neg_txt: Tensor, + neg_txt_mask: Tensor, + timesteps: Union[List[float], torch.Tensor], + num_inference_steps: int = 38, + guidance_scale: float = 4.0, + cfg_trunc_ratio: float = 1.0, + cfg_normalization: bool = True, ): """ Denoise an image using the NextDiT model. Args: + scheduler (): + Noise scheduler model (lumina_models.NextDiT): The NextDiT model instance. - img (Tensor): The input image tensor. - txt (Tensor): The input text tensor. - txt_mask (Tensor): The input text mask tensor. - timesteps (List[float]): A list of timesteps for the denoising process. - guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0. + img (Tensor): + The input image latent tensor. + txt (Tensor): + The input text tensor. + txt_mask (Tensor): + The input text mask tensor. + neg_txt (Tensor): + The negative input txt tensor + neg_txt_mask (Tensor): + The negative input text mask tensor. 
+ timesteps (List[Union[float, torch.FloatTensor]]): + A list of timesteps for the denoising process. + guidance_scale (float, optional): + The guidance scale for the denoising process. Defaults to 4.0. + cfg_trunc_ratio (float, optional): + The ratio of the timestep interval to apply normalization-based guidance scale. + cfg_normalization (bool, optional): + Whether to apply normalization-based guidance scale. Returns: - img (Tensor): Denoised tensor + img (Tensor): Denoised latent tensor """ - for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - # model.prepare_block_swap_before_forward() - # block_samples = None - # block_single_samples = None - pred = model.forward_with_cfg( - x=img, # image latents (B, C, H, W) - t=t_vec / 1000, # timesteps需要除以1000来匹配模型预期 + + for i, t in enumerate(tqdm(timesteps)): + # compute whether apply classifier-free truncation on this timestep + do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - t / scheduler.config.num_train_timesteps + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(img.shape[0]).to(model.device) + + noise_pred_cond = model( + img, + current_timestep, cap_feats=txt, # Gemma2的hidden states作为caption features cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask - cfg_scale=guidance, ) - img = img + (t_prev - t_curr) * pred + if not do_classifier_free_truncation: + noise_pred_uncond = model( + img, + current_timestep, + cap_feats=neg_txt, # Gemma2的hidden states作为caption features + cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + # apply normalization after classifier-free guidance + if 
cfg_normalization: + cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm) + else: + noise_pred = noise_pred_cond + + img_dtype = img.dtype + + if img.dtype != img_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + img = img.to(img_dtype) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = -noise_pred + img = scheduler.step(noise_pred, t, img, return_dict=False)[0] - # model.prepare_block_swap_before_forward() return img @@ -754,3 +886,14 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。", + ) + parser.add_argument( + "--system_prompt", + type=str, + default="You are an assistant designed to generate high-quality images based on user prompts. ", + help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。", + ) diff --git a/library/lumina_util.py b/library/lumina_util.py index f404e775..d9c89938 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -20,11 +20,13 @@ logger = logging.getLogger(__name__) MODEL_VERSION_LUMINA_V2 = "lumina2" + def load_lumina_model( ckpt_path: str, - dtype: torch.dtype, + dtype: Optional[torch.dtype], device: torch.device, disable_mmap: bool = False, + use_flash_attn: bool = False, ): """ Load the Lumina model from the checkpoint path. @@ -34,22 +36,22 @@ def load_lumina_model( dtype (torch.dtype): The data type for the model. device (torch.device): The device to load the model on. disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. 
+ use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False. Returns: model (lumina_models.NextDiT): The loaded model. """ logger.info("Building Lumina") with torch.device("meta"): - model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype) + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - state_dict = load_safetensors( - ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype - ) + state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) info = model.load_state_dict(state_dict, strict=False, assign=True) logger.info(f"Loaded Lumina: {info}") return model + def load_ae( ckpt_path: str, dtype: torch.dtype, @@ -74,9 +76,7 @@ def load_ae( ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors( - ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype - ) + sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae @@ -104,37 +104,35 @@ def load_gemma2( """ logger.info("Building Gemma2") GEMMA2_CONFIG = { - "_name_or_path": "google/gemma-2-2b", - "architectures": [ - "Gemma2Model" - ], - "attention_bias": False, - "attention_dropout": 0.0, - "attn_logit_softcapping": 50.0, - "bos_token_id": 2, - "cache_implementation": "hybrid", - "eos_token_id": 1, - "final_logit_softcapping": 30.0, - "head_dim": 256, - "hidden_act": "gelu_pytorch_tanh", - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2304, - "initializer_range": 0.02, - "intermediate_size": 9216, - "max_position_embeddings": 8192, - "model_type": "gemma2", - "num_attention_heads": 8, - "num_hidden_layers": 26, - "num_key_value_heads": 4, - "pad_token_id": 0, - 
"query_pre_attn_scalar": 256, - "rms_norm_eps": 1e-06, - "rope_theta": 10000.0, - "sliding_window": 4096, - "torch_dtype": "float32", - "transformers_version": "4.44.2", - "use_cache": True, - "vocab_size": 256000 + "_name_or_path": "google/gemma-2-2b", + "architectures": ["Gemma2Model"], + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": 50.0, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": 1, + "final_logit_softcapping": 30.0, + "head_dim": 256, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2304, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 8192, + "model_type": "gemma2", + "num_attention_heads": 8, + "num_hidden_layers": 26, + "num_key_value_heads": 4, + "pad_token_id": 0, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "sliding_window": 4096, + "torch_dtype": "float32", + "transformers_version": "4.44.2", + "use_cache": True, + "vocab_size": 256000, } config = Gemma2Config(**GEMMA2_CONFIG) @@ -145,9 +143,7 @@ def load_gemma2( sd = state_dict else: logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors( - ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype - ) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) for key in list(sd.keys()): new_key = key.replace("model.", "") @@ -159,6 +155,7 @@ def load_gemma2( logger.info(f"Loaded Gemma2: {info}") return gemma2 + def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: """ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 @@ -174,6 +171,7 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) return x + DIFFUSERS_TO_ALPHA_VLLM_MAP = { # Embedding layers "cap_embedder.0.weight": 
["time_caption_embed.caption_embedder.0.weight"], @@ -224,9 +222,7 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict for block_idx in range(num_double_blocks): if str(block_idx) in key: converted = pattern.replace("()", str(block_idx)) - new_key = key.replace( - converted, replacement.replace("()", str(block_idx)) - ) + new_key = key.replace(converted, replacement.replace("()", str(block_idx))) break if new_key == key: diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c4079884..6a4b39b3 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -610,6 +610,21 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import BaseOutput +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + @dataclass class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ @@ -649,22 +664,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): self, num_train_timesteps: int = 1000, shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. 
+ """ + return self._shift + @property def step_index(self): """ @@ -690,6 +732,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -709,10 +754,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: A scaled input sample. """ - if self.step_index is None: - self._init_step_index(timestep) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) - sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -720,7 +786,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, device: 
Union[str, torch.device] = None): + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -730,18 +826,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps - timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps.to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas self._step_index = None 
self._begin_index = None @@ -807,7 +934,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" @@ -823,30 +954,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + prev_sample = sample + (sigma_next - sigma) * model_output - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) - - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - - # if self.config.prediction_type == "vector_field": - - denoised = sample - model_output * sigma - # 2. 
Convert to an ODE derivative - derivative = (sample - denoised) / sigma_hat - - dt = self.sigmas[self.step_index + 1] - sigma_hat - - prev_sample = sample + derivative * dt # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) @@ -858,6 +969,86 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = 
self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def __len__(self): return self.config.num_train_timesteps diff --git a/library/strategy_base.py b/library/strategy_base.py index 358e42f1..fad79682 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -2,7 +2,7 @@ import os import re -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import torch @@ -430,9 +430,21 @@ class LatentsCachingStrategy: bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, - alpha_mask: bool, + apply_alpha_mask: bool, multi_resolution: bool = False, - ): + ) -> bool: + """ + Args: + 
latents_stride: stride of latents + bucket_reso: resolution of the bucket + npz_path: path to the npz file + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + multi_resolution: whether to use multi-resolution latents + + Returns: + bool + """ if not self.cache_to_disk: return False if not os.path.exists(npz_path): @@ -451,7 +463,7 @@ class LatentsCachingStrategy: return False if flip_aug and "latents_flipped" + key_reso_suffix not in npz: return False - if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: return False except Exception as e: logger.error(f"Error loading file: {npz_path}") @@ -462,22 +474,35 @@ class LatentsCachingStrategy: # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( self, - encode_by_vae, - vae_device, - vae_dtype, + encode_by_vae: Callable, + vae_device: torch.device, + vae_dtype: torch.dtype, image_infos: List, flip_aug: bool, - alpha_mask: bool, + apply_alpha_mask: bool, random_crop: bool, multi_resolution: bool = False, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. 
+ + Args: + encode_by_vae: function to encode images by VAE + vae_device: device to use for VAE + vae_dtype: dtype to use for VAE + image_infos: list of ImageInfo + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + random_crop: whether to random crop images + multi_resolution: whether to use multi-resolution latents + + Returns: + None """ from library import train_util # import here to avoid circular import img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop + image_infos, apply_alpha_mask, random_crop ) img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) @@ -519,12 +544,40 @@ class LatentsCachingStrategy: ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ for SD/SDXL + + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) def _default_load_latents_from_disk( self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + Args: + latents_stride (Optional[int]): Stride for latents. If None, load all latents. + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. 
+ + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask + """ if latents_stride is None: key_reso_suffix = "" else: @@ -552,6 +605,19 @@ class LatentsCachingStrategy: alpha_mask=None, key_reso_suffix="", ): + """ + Args: + npz_path (str): Path to the npz file. + latents_tensor (torch.Tensor): Latent tensor + original_size (List[int]): Original size of the image + crop_ltrb (List[int]): Crop left top right bottom + flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor + alpha_mask (Optional[torch.Tensor]): Alpha mask + key_reso_suffix (str): Key resolution suffix + + Returns: + None + """ kwargs = {} if os.path.exists(npz_path): diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 0a6a7f29..5d6e100f 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -3,13 +3,13 @@ import os from typing import Any, List, Optional, Tuple, Union import torch -from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast +from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast from library import train_util from library.strategy_base import ( LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy, - TextEncoderOutputsCachingStrategy + TextEncoderOutputsCachingStrategy, ) import numpy as np from library.utils import setup_logging @@ -37,21 +37,38 @@ class LuminaTokenizeStrategy(TokenizeStrategy): else: self.max_length = max_length - def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]: + def tokenize( + self, text: Union[str, List[str]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + token input ids, attention_masks + """ text = [text] if isinstance(text, str) else 
text encodings = self.tokenizer( text, max_length=self.max_length, return_tensors="pt", - padding=True, + padding="max_length", pad_to_multiple_of=8, - truncation=True, ) - return [encodings.input_ids, encodings.attention_mask] + return (encodings.input_ids, encodings.attention_mask) def tokenize_with_weights( self, text: str | List[str] ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + token input ids, attention_masks, weights + """ # Gemma doesn't support weighted prompts, return uniform weights tokens, attention_masks = self.tokenize(text) weights = [torch.ones_like(t) for t in tokens] @@ -66,9 +83,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: List[torch.Tensor], + tokens: Tuple[torch.Tensor, torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ text_encoder = models[0] + assert isinstance(text_encoder, Gemma2Model) input_ids, attention_masks = tokens outputs = text_encoder( @@ -84,9 +112,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: List[torch.Tensor], - weights_list: List[torch.Tensor], + tokens: Tuple[torch.Tensor, torch.Tensor], + weights: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + weights_list (List[torch.Tensor]): 
Currently unused + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ # For simplicity, use uniform weighting return self.encode_tokens(tokenize_strategy, models, tokens) @@ -114,7 +153,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX ) - def is_disk_cached_outputs_expected(self, npz_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + """ + Args: + npz_path (str): Path to the npz file. + + Returns: + bool: True if the npz file is expected to be cached. + """ if not self.cache_to_disk: return False if not os.path.exists(npz_path): @@ -141,7 +187,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) Load outputs from a npz file Returns: - List[np.ndarray]: hidden_state, input_ids, attention_mask + List[np.ndarray]: hidden_state, input_ids, attention_mask """ data = np.load(npz_path) hidden_state = data["hidden_state"] @@ -151,53 +197,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) def cache_batch_outputs( self, - tokenize_strategy: LuminaTokenizeStrategy, + tokenize_strategy: TokenizeStrategy, models: List[Any], - text_encoding_strategy: LuminaTextEncodingStrategy, - infos: List, - ): - lumina_text_encoding_strategy: LuminaTextEncodingStrategy = ( - text_encoding_strategy - ) - captions = [info.caption for info in infos] + text_encoding_strategy: TextEncodingStrategy, + batch: List[train_util.ImageInfo], + ) -> None: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + text_encoding_strategy (LuminaTextEncodingStrategy): + infos (List): List of image_info + + Returns: + None + """ + assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) + assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) + + captions = 
[(info.system_prompt or "") + info.caption for info in batch] if self.is_weighted: - tokens, weights_list = tokenize_strategy.tokenize_with_weights( - captions + tokens, attention_masks, weights_list = ( + tokenize_strategy.tokenize_with_weights(captions) ) with torch.no_grad(): - hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, models, tokens, weights_list + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + models, + (tokens, attention_masks), + weights_list, + ) ) else: tokens = tokenize_strategy.tokenize(captions) with torch.no_grad(): - hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens + ) ) if hidden_state.dtype != torch.float32: hidden_state = hidden_state.float() hidden_state = hidden_state.cpu().numpy() - attention_mask = attention_masks.cpu().numpy() - input_ids = tokens.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() # (B, S) + input_ids = input_ids.cpu().numpy() # (B, S) - - for i, info in enumerate(infos): + for i, info in enumerate(batch): hidden_state_i = hidden_state[i] attention_mask_i = attention_mask[i] input_ids_i = input_ids[i] + assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_path}" + if self.cache_to_disk: np.savez( info.text_encoder_outputs_npz, hidden_state=hidden_state_i, attention_mask=attention_mask_i, - input_ids=input_ids_i + input_ids=input_ids_i, ) else: - info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] + info.text_encoder_outputs = [ + hidden_state_i, + attention_mask_i, + input_ids_i, + ] class LuminaLatentsCachingStrategy(LatentsCachingStrategy): @@ -227,7 +295,14 @@ class 
LuminaLatentsCachingStrategy(LatentsCachingStrategy): npz_path: str, flip_aug: bool, alpha_mask: bool, - ): + ) -> bool: + """ + Args: + bucket_reso (Tuple[int, int]): The resolution of the bucket. + npz_path (str): Path to the npz file. + flip_aug (bool): Whether to flip the image. + alpha_mask (bool): Whether to apply + """ return self._default_is_disk_cached_latents_expected( 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True ) @@ -241,6 +316,20 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy): Optional[np.ndarray], Optional[np.ndarray], ]: + """ + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet + """ return self._default_load_latents_from_disk( 8, npz_path, bucket_reso ) # support multi-resolution diff --git a/library/train_util.py b/library/train_util.py index 4eccc4a0..230b2c4b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -195,7 +195,7 @@ class ImageInfo: self.latents_flipped: Optional[torch.Tensor] = None self.latents_npz: Optional[str] = None # set in cache_latents self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = ( None # crop left top right bottom in original pixel size, not latents size ) self.cond_img_path: Optional[str] = None @@ -211,6 +211,8 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.system_prompt: Optional[str] = None + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -434,6 +436,7 @@ class BaseSubset: custom_attributes: 
Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -464,6 +467,8 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split + self.system_prompt = system_prompt + class DreamBoothSubset(BaseSubset): def __init__( @@ -495,6 +500,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -522,6 +528,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.is_reg = is_reg @@ -564,6 +571,7 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -591,6 +599,7 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + system_prompt=system_prompt ) self.metadata_file = metadata_file @@ -629,6 +638,7 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + system_prompt: Optional[str] = None ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -656,6 +666,7 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, 
validation_split=validation_split, + system_prompt=system_prompt ) self.conditioning_data_dir = conditioning_data_dir @@ -1686,8 +1697,9 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: + system_prompt = subset.system_prompt or "" caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension # if self.XTI_layers: # caption_layer = [] # for layer in self.XTI_layers: @@ -2059,6 +2071,7 @@ class DreamBoothDataset(BaseDataset): num_train_images = 0 num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] + for subset in subsets: num_repeats = subset.num_repeats if self.is_training_dataset else 1 if num_repeats < 1: @@ -2086,7 +2099,7 @@ class DreamBoothDataset(BaseDataset): num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, (subset.system_prompt or "") + caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: @@ -2967,7 +2980,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size diff --git a/lumina_train_network.py b/lumina_train_network.py index 81acfb51..adbf834c 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -1,17 +1,17 @@ 
import argparse import copy -import math -import random -from typing import Any, Optional, Union, Tuple +from typing import Any, Tuple import torch -from torch import Tensor -from accelerate import Accelerator from library.device_utils import clean_memory_on_device, init_ipex init_ipex() +from torch import Tensor +from accelerate import Accelerator + + import train_network from library import ( lumina_models, @@ -40,10 +40,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def assert_extra_args(self, args, train_dataset_group, val_dataset_group): super().assert_extra_args(args, train_dataset_group, val_dataset_group) - if ( - args.cache_text_encoder_outputs_to_disk - and not args.cache_text_encoder_outputs - ): + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning("Enabling cache_text_encoder_outputs due to disk caching") args.cache_text_encoder_outputs = True @@ -59,17 +56,14 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): model = lumina_util.load_lumina_model( args.pretrained_model_name_or_path, loading_dtype, - "cpu", + torch.device("cpu"), disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, ) if args.fp8_base: # check dtype of model - if ( - model.dtype == torch.float8_e4m3fnuz - or model.dtype == torch.float8_e5m2 - or model.dtype == torch.float8_e5m2fnuz - ): + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 Lumina 2 model") @@ -92,17 +86,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model def get_tokenize_strategy(self, args): - return strategy_lumina.LuminaTokenizeStrategy( - args.gemma2_max_token_length, args.tokenizer_cache_dir - ) + return 
strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): return [tokenize_strategy.tokenizer] def get_latents_caching_strategy(self, args): - return strategy_lumina.LuminaLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, False - ) + return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) def get_text_encoding_strategy(self, args): return strategy_lumina.LuminaTextEncodingStrategy() @@ -144,15 +134,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to( - accelerator.device, dtype=weight_dtype - ) # always not fp8 + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 if text_encoders[0].dtype == torch.float8_e4m3fn: # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8( - 1, text_encoders[1], text_encoders[1].dtype, weight_dtype - ) + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) else: # otherwise, we need to convert it to target dtype text_encoders[0].to(weight_dtype) @@ -162,35 +148,36 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # cache sample prompts if args.sample_prompts is not None: - logger.info( - f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" - ) + logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}") tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert 
isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) + system_prompt = args.system_prompt or "" sample_prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = ( - {} - ) # key: prompt, value: text encoder outputs + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in sample_prompts: - prompts = [prompt_dict.get("prompt", ""), - prompt_dict.get("negative_prompt", "")] - logger.info( - f"cache Text Encoder outputs for prompt: {prompts[0]}" - ) - tokens_and_masks = tokenize_strategy.tokenize(prompts) - sample_prompts_te_outputs[prompts[0]] = ( - text_encoding_strategy.encode_tokens( + prompts = [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ] + for prompt in prompts: + prompt = system_prompt + prompt + if prompt in sample_prompts_te_outputs: + continue + + logger.info(f"cache Text Encoder outputs for prompt: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens( tokenize_strategy, text_encoders, tokens_and_masks, ) - ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs accelerator.wait_for_everyone() @@ -235,12 +222,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # Remaining methods maintain similar structure to flux implementation # with Lumina-specific model calls and strategies - def get_noise_scheduler( - self, args: argparse.Namespace, device: torch.device - ) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler( - num_train_timesteps=1000, shift=args.discrete_flow_shift - ) + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return 
noise_scheduler @@ -258,26 +241,45 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): noise_scheduler, latents, batch, - text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks) + text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks) dit: lumina_models.NextDiT, network, weight_dtype, train_unet, is_train=True, ): + assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler) noise = torch.randn_like(latents) bsz = latents.shape[0] - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = ( - flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype - ) + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = lumina_train_util.compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) - # May not need to pack/unpack? - # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 - # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 + # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `latents` + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -289,48 +291,35 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # Unpack Gemma2 outputs gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds - def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): + def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = dit( x=img, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features - cap_mask=gemma2_attn_mask.to( - dtype=torch.int32 - ), # Gemma2的attention mask + cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask ) return model_pred model_pred = call_dit( img=noisy_model_input, gemma2_hidden_states=gemma2_hidden_states, - timesteps=timesteps, gemma2_attn_mask=gemma2_attn_mask, + timesteps=timesteps, ) - # May not need to pack/unpack? 
- # unpack latents - # model_pred = lumina_util.unpack_latents( - # model_pred, packed_latent_height, packed_latent_width - # ) - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type( - args, model_pred, noisy_model_input, sigmas - ) + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - # flow matching loss: this is different from SD3 - target = noise - latents + # flow matching loss + target = latents - noise # differential output preservation if "custom_attributes" in batch: diff_output_pr_indices = [] for i, custom_attributes in enumerate(batch["custom_attributes"]): - if ( - "diff_output_preservation" in custom_attributes - and custom_attributes["diff_output_preservation"] - ): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: diff_output_pr_indices.append(i) if len(diff_output_pr_indices) > 0: @@ -338,9 +327,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): with torch.no_grad(): model_pred_prior = call_dit( img=noisy_model_input[diff_output_pr_indices], - gemma2_hidden_states=gemma2_hidden_states[ - diff_output_pr_indices - ], + gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices], timesteps=timesteps[diff_output_pr_indices], gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]), ) @@ -363,9 +350,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec( - None, args, False, True, False, lumina="lumina2" - ) + return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2") def update_metadata(self, metadata, args): metadata["ss_weighting_scheme"] = args.weighting_scheme @@ -384,12 +369,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): 
text_encoder.embed_tokens.requires_grad_(True) - def prepare_text_encoder_fp8( - self, index, text_encoder, te_weight_dtype, weight_dtype - ): - logger.info( - f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}" - ) + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") text_encoder.to(te_weight_dtype) # fp8 text_encoder.embed_tokens.to(dtype=weight_dtype) @@ -402,12 +383,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # if we doesn't swap blocks, we can move the model to device nextdit = unet assert isinstance(nextdit, lumina_models.NextDiT) - nextdit = accelerator.prepare( - nextdit, device_placement=[not self.is_swapping_blocks] - ) - accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( - accelerator.device - ) # reduce peak memory usage + nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() return nextdit diff --git a/train_network.py b/train_network.py index 674f1cb6..2cf11af7 100644 --- a/train_network.py +++ b/train_network.py @@ -129,7 +129,7 @@ class NetworkTrainer: if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) - def load_target_model(self, args, weight_dtype, accelerator): + def load_target_model(self, args, weight_dtype, accelerator) -> tuple: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む @@ -354,12 +354,13 @@ class NetworkTrainer: if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is 
None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders),