Update lumina_train_util.py

Change the apply_model_prediction_type function to suitable new call_dit
This commit is contained in:
duongve13112002
2025-09-29 20:33:41 +07:00
committed by GitHub
parent a9aa707b84
commit b869b5d95c

View File

@@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
import torch
from torch import Tensor
from torch.distributions import LogNormal
from accelerate import Accelerator, PartialState
from transformers import Gemma2Model
from tqdm import tqdm
@@ -808,6 +809,7 @@ def get_noisy_model_input_and_timesteps(
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
@@ -815,58 +817,54 @@ def get_noisy_model_input_and_timesteps(
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps (reversed for Lumina: t=0 noise, t=1 image)
timesteps
sigmas
"""
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
# Reverse for Lumina: t=0 is noise, t=1 is image
t_lumina = 1.0 - t
timesteps = t_lumina * 1000.0
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale
t = logits_norm.sigmoid()
t = (t * shift) / (1 + (shift - 1) * t)
# Reverse for Lumina: t=0 is noise, t=1 is image
t_lumina = 1.0 - t
timesteps = t_lumina * 1000.0
t = t.view(-1, 1, 1, 1)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
t = time_shift(mu, 1.0, t)
timesteps = t * 1000.0
timesteps = 1 - t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
elif args.timestep_sampling == "lognorm":
u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
t = torch.sigmoid(u) # maps to [0,1]
lognormal = LogNormal(loc=0, scale=0.333)
t = lognormal.sample((int(timesteps * args.lognorm_alpha),)).to(device)
timesteps = t * 1000.0
t = ((1 - t/t.max()) * 1000)
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -878,19 +876,14 @@ def get_noisy_model_input_and_timesteps(
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)
# Reverse for Lumina convention
timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal
# Calculate sigmas with normal timesteps, then reverse interpolation
sigmas_normal = get_sigmas(
noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
)
# Reverse sigma interpolation for Lumina
sigmas = 1.0 - sigmas_normal
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -1064,10 +1057,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
choices=["sigma", "uniform", "sigmoid", "shift", "lognorm", "nextdit_shift"],
default="shift",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, lognorm, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid, lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
)
parser.add_argument(
"--sigmoid_scale",
@@ -1075,6 +1068,13 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--lognorm_alpha",
type=float,
default=0.75,
help='Alpha factor for distribute timestep to the center/early (only used when timestep-sampling is "lognorm"). / 中心早期へのタイムステップ分配のアルファ係数timestep-samplingが"lognorm"の場合のみ有効)。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],