From b869b5d95c25beee75ad5de00200abf75c73e6a0 Mon Sep 17 00:00:00 2001 From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com> Date: Mon, 29 Sep 2025 20:33:41 +0700 Subject: [PATCH] Update lumina_train_util.py Change the apply_model_prediction_type function to suitable new call_dit --- library/lumina_train_util.py | 72 ++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 31b9a2da..56b5c0b5 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator import torch from torch import Tensor +from torch.distributions import LogNormal from accelerate import Accelerator, PartialState from transformers import Gemma2Model from tqdm import tqdm @@ -808,6 +809,7 @@ def get_noisy_model_input_and_timesteps( ) -> Tuple[Tensor, Tensor, Tensor]: """ Get noisy model input and timesteps. + Args: args (argparse.Namespace): Arguments. noise_scheduler (noise_scheduler): Noise scheduler. @@ -815,58 +817,54 @@ def get_noisy_model_input_and_timesteps( noise (Tensor): Latent noise. device (torch.device): Device. dtype (torch.dtype): Data type + Return: Tuple[Tensor, Tensor, Tensor]: noisy model input - timesteps (reversed for Lumina: t=0 noise, t=1 image) + timesteps sigmas """ bsz, _, h, w = latents.shape sigmas = None - + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: t = torch.rand((bsz,), device=device) - - # Reverse for Lumina: t=0 is noise, t=1 is image - t_lumina = 1.0 - t - timesteps = t_lumina * 1000.0 + + timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents - elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale - t = logits_norm.sigmoid() - t = (t * shift) / (1 + (shift - 1) * t) - - # Reverse for Lumina: t=0 is noise, t=1 is image - t_lumina = 1.0 - t - timesteps = t_lumina * 1000.0 - t = t.view(-1, 1, 1, 1) + logits_norm = ( + logits_norm * args.sigmoid_scale + ) # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 noisy_model_input = (1 - t) * noise + t * latents - elif args.timestep_sampling == "nextdit_shift": t = torch.rand((bsz,), device=device) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) t = time_shift(mu, 1.0, t) - timesteps = t * 1000.0 + timesteps = 1 - t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents - elif args.timestep_sampling == "lognorm": - u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device) - t = torch.sigmoid(u) # maps to [0,1] + lognormal = LogNormal(loc=0, scale=0.333) + t = lognormal.sample((int(timesteps * args.lognorm_alpha),)).to(device) - timesteps = t * 1000.0 + t = ((1 - t/t.max()) * 1000) t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents - else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -878,19 +876,14 @@ def get_noisy_model_input_and_timesteps( mode_scale=args.mode_scale, ) indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps_normal = noise_scheduler.timesteps[indices].to(device=device) - - # Reverse for Lumina convention - timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal - - # Calculate sigmas with normal timesteps, then reverse interpolation - sigmas_normal = get_sigmas( - noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas( + noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype ) - # Reverse sigma interpolation for Lumina - sigmas = 1.0 - sigmas_normal noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise - + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas @@ -1064,10 +1057,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "lognorm", "nextdit_shift"], default="shift", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, lognorm, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid, lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", ) parser.add_argument( "--sigmoid_scale", @@ -1075,6 +1068,13 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): default=1.0, help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', ) + + parser.add_argument( + "--lognorm_alpha", + type=float, + default=0.75, + help='Alpha factor for distribute timestep to the center/early (only used when timestep-sampling is "lognorm"). / 中心/早期へのタイムステップ分配のアルファ係数(timestep-samplingが"lognorm"の場合のみ有効)。', + ) parser.add_argument( "--model_prediction_type", choices=["raw", "additive", "sigma_scaled"],