Merge pull request #2 from duongve13112002/fix_lumina_image_v2_reversed_timesteps

Fix lumina image v2 reversed timesteps
Authored by duongve13112002 on 2025-09-29 20:43:56 +07:00; committed by GitHub.
2 changed files with 99 additions and 91 deletions


@@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
 import torch
 from torch import Tensor
+from torch.distributions import LogNormal
 from accelerate import Accelerator, PartialState
 from transformers import Gemma2Model
 from tqdm import tqdm
@@ -808,6 +809,7 @@ def get_noisy_model_input_and_timesteps(
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Get noisy model input and timesteps.
+
     Args:
         args (argparse.Namespace): Arguments.
         noise_scheduler (noise_scheduler): Noise scheduler.
@@ -815,58 +817,54 @@ def get_noisy_model_input_and_timesteps(
         noise (Tensor): Latent noise.
         device (torch.device): Device.
         dtype (torch.dtype): Data type
     Return:
         Tuple[Tensor, Tensor, Tensor]:
             noisy model input
-            timesteps (reversed for Lumina: t=0 noise, t=1 image)
+            timesteps
             sigmas
     """
     bsz, _, h, w = latents.shape
     sigmas = None
     if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
         # Simple random t-based noise sampling
         if args.timestep_sampling == "sigmoid":
             # https://github.com/XLabs-AI/x-flux/tree/main
             t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
         else:
             t = torch.rand((bsz,), device=device)
-        # Reverse for Lumina: t=0 is noise, t=1 is image
-        t_lumina = 1.0 - t
-        timesteps = t_lumina * 1000.0
+        timesteps = t * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "shift":
         shift = args.discrete_flow_shift
         logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale
-        t = logits_norm.sigmoid()
-        t = (t * shift) / (1 + (shift - 1) * t)
-        # Reverse for Lumina: t=0 is noise, t=1 is image
-        t_lumina = 1.0 - t
-        timesteps = t_lumina * 1000.0
-        t = t.view(-1, 1, 1, 1)
+        logits_norm = (
+            logits_norm * args.sigmoid_scale
+        )  # larger scale for more uniform sampling
+        timesteps = logits_norm.sigmoid()
+        timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
+        t = timesteps.view(-1, 1, 1, 1)
+        timesteps = timesteps * 1000.0
         noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "nextdit_shift":
         t = torch.rand((bsz,), device=device)
         mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
         t = time_shift(mu, 1.0, t)
-        timesteps = t * 1000.0
+        timesteps = 1 - t * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "lognorm":
-        u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
-        t = torch.sigmoid(u)  # maps to [0,1]
+        lognormal = LogNormal(loc=0, scale=0.333)
+        t = lognormal.sample((int(timesteps * args.lognorm_alpha),)).to(device)
         timesteps = t * 1000.0
-        t = ((1 - t/t.max()) * 1000)
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
     else:
         # Sample a random timestep for each image
         # for weighting schemes where we sample timesteps non-uniformly
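
The sigmoid and shift samplers above are standard flow-matching timestep transforms. A minimal standalone sketch of what they compute, using bsz, sigmoid_scale, and shift as stand-ins for the corresponding args values:

    import torch

    # Stand-ins for args.* values used by the branches above.
    bsz, sigmoid_scale, shift = 4, 1.0, 6.0

    # "sigmoid": squash a scaled standard normal into (0, 1); a larger
    # sigmoid_scale spreads the draws closer to uniform over (0, 1).
    t = torch.sigmoid(sigmoid_scale * torch.randn(bsz))

    # "shift": warp the draw with t' = (t * shift) / (1 + (shift - 1) * t),
    # which is monotone on [0, 1] and, for shift > 1, pushes mass toward t' = 1.
    t_shifted = (t * shift) / (1 + (shift - 1) * t)

    # Both branches then scale to the scheduler's 1000-step range.
    timesteps = t_shifted * 1000.0

The shift warp is the same discrete-flow-shift mapping used by Flux-style trainers; being monotone, it reorders no samples, only their density.
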
@@ -878,19 +876,14 @@ def get_noisy_model_input_and_timesteps(
             mode_scale=args.mode_scale,
         )
         indices = (u * noise_scheduler.config.num_train_timesteps).long()
-        timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)
-        # Reverse for Lumina convention
-        timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal
-        # Calculate sigmas with normal timesteps, then reverse interpolation
-        sigmas_normal = get_sigmas(
-            noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
+        timesteps = noise_scheduler.timesteps[indices].to(device=device)
+        # Add noise according to flow matching.
+        sigmas = get_sigmas(
+            noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
         )
-        # Reverse sigma interpolation for Lumina
-        sigmas = 1.0 - sigmas_normal
         noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise

     return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
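
Throughout the function, the unchanged interpolation is noisy = (1 - t) * noise + t * latents, so t = 0 yields pure noise and t = 1 the clean latents; the fix makes the reported timesteps follow this same t instead of its reverse. A quick endpoint check, with illustrative tensor shapes:

    import torch

    # Endpoint check for the flow-matching interpolation used above: t = 0
    # gives pure noise, t = 1 the clean latents; after this fix the reported
    # timesteps are t * 1000 rather than (1 - t) * 1000.
    latents = torch.randn(2, 16, 32, 32)
    noise = torch.randn_like(latents)

    for t_val in (0.0, 1.0):
        t = torch.full((2, 1, 1, 1), t_val)
        noisy = (1 - t) * noise + t * latents
        target = noise if t_val == 0.0 else latents
        assert torch.allclose(noisy, target)
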
@@ -1064,10 +1057,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--timestep_sampling",
-        choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
+        choices=["sigma", "uniform", "sigmoid", "shift", "lognorm", "nextdit_shift"],
         default="shift",
-        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, and NextDIT.1 shifting. Default is 'shift'."
-        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
+        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, lognorm, shift of sigmoid, and NextDIT.1 shifting. Default is 'shift'."
+        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
     )
     parser.add_argument(
         "--sigmoid_scale",
@@ -1075,6 +1068,13 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
         default=1.0,
         help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
     )
+    parser.add_argument(
+        "--lognorm_alpha",
+        type=float,
+        default=0.75,
+        help='Alpha factor for distributing timesteps toward the center/early steps (only used when timestep-sampling is "lognorm"). / タイムステップを中心/早期に分配するためのアルファ係数(timestep-samplingが"lognorm"の場合のみ有効)。',
+    )
     parser.add_argument(
         "--model_prediction_type",
         choices=["raw", "additive", "sigma_scaled"],

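The new --lognorm_alpha option feeds the lognorm branch above. The following is an illustrative sketch only of drawing log-normal timesteps with the diff's loc=0, scale=0.333 parameters; the batch size and the normalization into [0, 1] are added assumptions for a self-contained demo, since raw LogNormal samples are unbounded above:

    import torch
    from torch.distributions import LogNormal

    # Illustrative only: log-normal timestep draws in the style of the new
    # "lognorm" branch. bsz and the clamp into [0, 1] are assumptions here.
    bsz = 4
    lognormal = LogNormal(loc=0.0, scale=0.333)
    t = lognormal.sample((bsz,))
    t = (t / t.max()).clamp(0.0, 1.0)  # normalize before interpolating
    timesteps = t * 1000.0
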

@@ -361,70 +361,78 @@ def train(args):
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")

-    if args.blockwise_fused_optimizers:
-        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
-        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
-        # This balances memory usage and management complexity.
+    # if args.blockwise_fused_optimizers:
+    #     # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+    #     # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
+    #     # This balances memory usage and management complexity.

-        # split params into groups. currently different learning rates are not supported
-        grouped_params = []
-        param_group = {}
-        for group in params_to_optimize:
-            named_parameters = list(nextdit.named_parameters())
-            assert len(named_parameters) == len(
-                group["params"]
-            ), "number of parameters does not match"
-            for p, np in zip(group["params"], named_parameters):
-                # determine target layer and block index for each parameter
-                block_type = "other"  # double, single or other
-                if np[0].startswith("double_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "double"
-                elif np[0].startswith("single_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "single"
-                else:
-                    block_index = -1
+    #     # split params into groups. currently different learning rates are not supported
+    #     grouped_params = []
+    #     param_group = {}
+    #     for group in params_to_optimize:
+    #         named_parameters = list(nextdit.named_parameters())
+    #         assert len(named_parameters) == len(
+    #             group["params"]
+    #         ), "number of parameters does not match"
+    #         for p, np in zip(group["params"], named_parameters):
+    #             # determine target layer and block index for each parameter
+    #             block_type = "other"  # double, single or other
+    #             if np[0].startswith("double_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "double"
+    #             elif np[0].startswith("single_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "single"
+    #             else:
+    #                 block_index = -1

-                param_group_key = (block_type, block_index)
-                if param_group_key not in param_group:
-                    param_group[param_group_key] = []
-                param_group[param_group_key].append(p)
+    #             param_group_key = (block_type, block_index)
+    #             if param_group_key not in param_group:
+    #                 param_group[param_group_key] = []
+    #             param_group[param_group_key].append(p)

-        block_types_and_indices = []
-        for param_group_key, param_group in param_group.items():
-            block_types_and_indices.append(param_group_key)
-            grouped_params.append({"params": param_group, "lr": args.learning_rate})
+    #     block_types_and_indices = []
+    #     for param_group_key, param_group in param_group.items():
+    #         block_types_and_indices.append(param_group_key)
+    #         grouped_params.append({"params": param_group, "lr": args.learning_rate})

-            num_params = 0
-            for p in param_group:
-                num_params += p.numel()
-            accelerator.print(f"block {param_group_key}: {num_params} parameters")
+    #         num_params = 0
+    #         for p in param_group:
+    #             num_params += p.numel()
+    #         accelerator.print(f"block {param_group_key}: {num_params} parameters")

-        # prepare optimizers for each group
-        optimizers = []
-        for group in grouped_params:
-            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
-            optimizers.append(optimizer)
-        optimizer = optimizers[0]  # avoid error in the following code
+    #     # prepare optimizers for each group
+    #     optimizers = []
+    #     for group in grouped_params:
+    #         _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+    #         optimizers.append(optimizer)
+    #     optimizer = optimizers[0]  # avoid error in the following code

-        logger.info(
-            f"using {len(optimizers)} optimizers for blockwise fused optimizers"
-        )
+    #     logger.info(
+    #         f"using {len(optimizers)} optimizers for blockwise fused optimizers"
+    #     )

-        if train_util.is_schedulefree_optimizer(optimizers[0], args):
-            raise ValueError(
-                "Schedule-free optimizer is not supported with blockwise fused optimizers"
-            )
-        optimizer_train_fn = lambda: None  # dummy function
-        optimizer_eval_fn = lambda: None  # dummy function
-    else:
-        _, _, optimizer = train_util.get_optimizer(
-            args, trainable_params=params_to_optimize
-        )
-        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
-            optimizer, args
-        )
+    #     if train_util.is_schedulefree_optimizer(optimizers[0], args):
+    #         raise ValueError(
+    #             "Schedule-free optimizer is not supported with blockwise fused optimizers"
+    #         )
+    #     optimizer_train_fn = lambda: None  # dummy function
+    #     optimizer_eval_fn = lambda: None  # dummy function
+    # else:
+    #     _, _, optimizer = train_util.get_optimizer(
+    #         args, trainable_params=params_to_optimize
+    #     )
+    #     optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+    #         optimizer, args
+    #     )
+
+    # Currently, when using blockwise_fused_optimizers, the model weights are not updated.
+    _, _, optimizer = train_util.get_optimizer(
+        args, trainable_params=params_to_optimize
+    )
+    optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+        optimizer, args
+    )

     # prepare dataloader
     # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
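
For context on what the second file disables: the commented-out path followed the PyTorch optimizer-step-in-backward tutorial linked in its first comment, creating one optimizer per block of parameters. A minimal sketch of the underlying technique, independent of this trainer (assumes PyTorch 2.1+ for register_post_accumulate_grad_hook, and uses one optimizer per parameter rather than per block for brevity):

    import torch

    # Fused-backward sketch: give each parameter its own optimizer and step it
    # inside a post-accumulate-grad hook, so gradients are consumed (and freed)
    # during backward instead of being held for a global optimizer.step().
    model = torch.nn.Linear(8, 8)
    optimizers = {p: torch.optim.AdamW([p], lr=1e-4) for p in model.parameters()}

    def step_in_backward(param: torch.Tensor) -> None:
        optimizers[param].step()
        optimizers[param].zero_grad()

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(step_in_backward)

    loss = model(torch.randn(2, 8)).sum()
    loss.backward()  # parameters are updated here; no separate optimizer.step()

Because each gradient is consumed and freed inside backward, peak memory drops, but there is no single global optimizer left to step or schedule, which is why the removed code stubbed out optimizer_train_fn/optimizer_eval_fn and rejected schedule-free optimizers.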