From a9aa707b8473c2b40ce582bfc882f923dc80f4a8 Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 16:17:06 +0700
Subject: [PATCH 1/9] Fix timestep sampling in the
 get_noisy_model_input_and_timesteps function for Lumina Image v2 and add a
 new timestep type

Resolve the issue reported at
https://github.com/kohya-ss/sd-scripts/issues/2201 and introduce a new
timestep type called "lognorm".
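For reference, a minimal sketch of the sampling convention this patch moves
to (illustrative only; the tensor shapes and names are made up and are not
part of the diff). NextDiT treats t=0 as pure noise and t=1 as the clean
image, so a sampler that draws t in [0, 1] reports the scheduler-style
timestep on the reversed scale:

    import torch

    bsz = 4
    latents = torch.randn(bsz, 16, 64, 64)  # stand-in for VAE latents
    noise = torch.randn_like(latents)

    t = torch.rand(bsz)                      # 0 = pure noise, 1 = clean image
    timesteps = (1.0 - t) * 1000.0           # reversed, scheduler-style scale
    t = t.view(-1, 1, 1, 1)
    noisy_model_input = (1 - t) * noise + t * latents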
---
 library/lumina_train_util.py | 57 +++++++++++++++++++++++-------------
 1 file changed, 36 insertions(+), 21 deletions(-)

diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
index d5d5db05..31b9a2da 100644
--- a/library/lumina_train_util.py
+++ b/library/lumina_train_util.py
@@ -808,7 +808,6 @@ def get_noisy_model_input_and_timesteps(
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Get noisy model input and timesteps.
-
     Args:
         args (argparse.Namespace): Arguments.
         noise_scheduler (noise_scheduler): Noise scheduler.
@@ -816,39 +815,41 @@ def get_noisy_model_input_and_timesteps(
         noise (Tensor): Latent noise.
         device (torch.device): Device.
         dtype (torch.dtype): Data type
-
     Return:
         Tuple[Tensor, Tensor, Tensor]: noisy model input
-        timesteps
+        timesteps (reversed for Lumina: t=0 noise, t=1 image)
         sigmas
     """
     bsz, _, h, w = latents.shape
     sigmas = None
-
+
     if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
         # Simple random t-based noise sampling
         if args.timestep_sampling == "sigmoid":
-            # https://github.com/XLabs-AI/x-flux/tree/main
             t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
         else:
             t = torch.rand((bsz,), device=device)
-
-        timesteps = t * 1000.0
+
+        # Reverse for Lumina: t=0 is noise, t=1 is image
+        t_lumina = 1.0 - t
+        timesteps = t_lumina * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
+
     elif args.timestep_sampling == "shift":
         shift = args.discrete_flow_shift
         logits_norm = torch.randn(bsz, device=device)
-        logits_norm = (
-            logits_norm * args.sigmoid_scale
-        )  # larger scale for more uniform sampling
-        timesteps = logits_norm.sigmoid()
-        timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
-
-        t = timesteps.view(-1, 1, 1, 1)
-        timesteps = timesteps * 1000.0
+        logits_norm = logits_norm * args.sigmoid_scale
+        t = logits_norm.sigmoid()
+        t = (t * shift) / (1 + (shift - 1) * t)
+
+        # Reverse for Lumina: t=0 is noise, t=1 is image
+        t_lumina = 1.0 - t
+        timesteps = t_lumina * 1000.0
+
+        t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
+
     elif args.timestep_sampling == "nextdit_shift":
         t = torch.rand((bsz,), device=device)
         mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
@@ -857,6 +858,15 @@ def get_noisy_model_input_and_timesteps(
         t = time_shift(mu, 1.0, t)
         timesteps = t * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
+
+    elif args.timestep_sampling == "lognorm":
+        u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
+        t = torch.sigmoid(u)  # maps to [0,1]
+
+        timesteps = t * 1000.0
+        t = t.view(-1, 1, 1, 1)
+        noisy_model_input = (1 - t) * noise + t * latents
+
     else:
         # Sample a random timestep for each image
         # for weighting schemes where we sample timesteps non-uniformly
         u = compute_density_for_timestep_sampling(
             weighting_scheme=args.weighting_scheme,
             batch_size=bsz,
             logit_mean=args.logit_mean,
             logit_std=args.logit_std,
             mode_scale=args.mode_scale,
         )
         indices = (u * noise_scheduler.config.num_train_timesteps).long()
-        timesteps = noise_scheduler.timesteps[indices].to(device=device)
-
-        # Add noise according to flow matching.
-        sigmas = get_sigmas(
-            noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
+        timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)
+
+        # Reverse for Lumina convention
+        timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal
+
+        # Calculate sigmas with normal timesteps, then reverse interpolation
+        sigmas_normal = get_sigmas(
+            noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
         )
+        # Reverse sigma interpolation for Lumina
+        sigmas = 1.0 - sigmas_normal
         noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
-
+
     return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

From b32d66cfd273aa1a458674dcf5ebef46b5fbb6df Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 19:47:32 +0700
Subject: [PATCH 2/9] Update lumina_train_network.py

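NextDiT takes a normalized time in [0, 1], while the sampled timesteps live
on a 0-1000 scale that counts the opposite way. A quick sketch of the
conversion now applied at the call site (made-up values, shown only to
document the direction of the mapping):

    import torch

    timesteps = torch.tensor([1000.0, 500.0, 0.0])  # 1000 = pure noise
    t_model = 1 - timesteps / 1000                  # 0 = noise, 1 = clean image
    print(t_model)  # tensor([0.0000, 0.5000, 1.0000])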
---
 lumina_train_network.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lumina_train_network.py b/lumina_train_network.py
index b08e3143..095bca24 100644
--- a/lumina_train_network.py
+++ b/lumina_train_network.py
@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
             # NextDiT forward expects (x, t, cap_feats, cap_mask)
             model_pred = dit(
                 x=img,  # image latents (B, C, H, W)
-                t=timesteps / 1000,  # timesteps需要除以1000来匹配模型预期
+                t=1 - timesteps / 1000,  # divide timesteps by 1000 to match what the model expects
                 cap_feats=gemma2_hidden_states,  # Gemma2的hidden states作为caption features
                 cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2的attention mask
             )

From b869b5d95c25beee75ad5de00200abf75c73e6a0 Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 20:33:41 +0700
Subject: [PATCH 3/9] Update lumina_train_util.py

Change the apply_model_prediction_type function to suit the new call_dit.

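The reworked "lognorm" branch draws from torch.distributions.LogNormal and
squashes the draws onto the timestep range. A self-contained sketch of that
idea (an illustration only, not the code in the diff: it samples exactly one
value per batch element and keeps t in [0, 1] before scaling, with bsz
standing in for the batch size):

    import torch
    from torch.distributions import LogNormal

    bsz = 4
    s = LogNormal(loc=0.0, scale=0.333).sample((bsz,))  # positive, right-skewed
    t = 1.0 - s / s.max()    # normalize into [0, 1); shape set by the draw
    timesteps = t * 1000.0
    t = t.view(-1, 1, 1, 1)  # broadcast over (B, C, H, W) latents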
---
 library/lumina_train_util.py | 72 ++++++++++++++++++------------------
 1 file changed, 36 insertions(+), 36 deletions(-)

diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
index 31b9a2da..56b5c0b5 100644
--- a/library/lumina_train_util.py
+++ b/library/lumina_train_util.py
@@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
 
 import torch
 from torch import Tensor
+from torch.distributions import LogNormal
 from accelerate import Accelerator, PartialState
 from transformers import Gemma2Model
 from tqdm import tqdm
@@ -808,6 +809,7 @@ def get_noisy_model_input_and_timesteps(
 ) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Get noisy model input and timesteps.
+
     Args:
         args (argparse.Namespace): Arguments.
         noise_scheduler (noise_scheduler): Noise scheduler.
@@ -815,58 +817,54 @@ def get_noisy_model_input_and_timesteps(
         noise (Tensor): Latent noise.
         device (torch.device): Device.
         dtype (torch.dtype): Data type
+
     Return:
         Tuple[Tensor, Tensor, Tensor]: noisy model input
-        timesteps (reversed for Lumina: t=0 noise, t=1 image)
+        timesteps
         sigmas
     """
     bsz, _, h, w = latents.shape
     sigmas = None
-
+
     if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
         # Simple random t-based noise sampling
         if args.timestep_sampling == "sigmoid":
+            # https://github.com/XLabs-AI/x-flux/tree/main
             t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
         else:
             t = torch.rand((bsz,), device=device)
-
-        # Reverse for Lumina: t=0 is noise, t=1 is image
-        t_lumina = 1.0 - t
-        timesteps = t_lumina * 1000.0
+
+        timesteps = t * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
-
     elif args.timestep_sampling == "shift":
         shift = args.discrete_flow_shift
         logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale
-        t = logits_norm.sigmoid()
-        t = (t * shift) / (1 + (shift - 1) * t)
-
-        # Reverse for Lumina: t=0 is noise, t=1 is image
-        t_lumina = 1.0 - t
-        timesteps = t_lumina * 1000.0
-
-        t = t.view(-1, 1, 1, 1)
+        logits_norm = (
+            logits_norm * args.sigmoid_scale
+        )  # larger scale for more uniform sampling
+        timesteps = logits_norm.sigmoid()
+        timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
+
+        t = timesteps.view(-1, 1, 1, 1)
+        timesteps = timesteps * 1000.0
         noisy_model_input = (1 - t) * noise + t * latents
-
     elif args.timestep_sampling == "nextdit_shift":
         t = torch.rand((bsz,), device=device)
         mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
         t = time_shift(mu, 1.0, t)
-        timesteps = t * 1000.0
+        timesteps = (1 - t) * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
-
     elif args.timestep_sampling == "lognorm":
-        u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
-        t = torch.sigmoid(u)  # maps to [0,1]
+        lognormal = LogNormal(loc=0, scale=0.333)
+        t = lognormal.sample((int(timesteps * args.lognorm_alpha),)).to(device)
 
-        timesteps = t * 1000.0
+        t = ((1 - t/t.max()) * 1000)
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
-
     else:
         # Sample a random timestep for each image
         # for weighting schemes where we sample timesteps non-uniformly
         u = compute_density_for_timestep_sampling(
             weighting_scheme=args.weighting_scheme,
             batch_size=bsz,
             logit_mean=args.logit_mean,
             logit_std=args.logit_std,
             mode_scale=args.mode_scale,
         )
         indices = (u * noise_scheduler.config.num_train_timesteps).long()
-        timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)
-
-        # Reverse for Lumina convention
-        timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal
-
-        # Calculate sigmas with normal timesteps, then reverse interpolation
-        sigmas_normal = get_sigmas(
-            noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
+        timesteps = noise_scheduler.timesteps[indices].to(device=device)
+
+        # Add noise according to flow matching.
+        sigmas = get_sigmas(
+            noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
         )
-        # Reverse sigma interpolation for Lumina
-        sigmas = 1.0 - sigmas_normal
         noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
-
+
     return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -1064,10 +1057,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
 
     parser.add_argument(
         "--timestep_sampling",
-        choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
+        choices=["sigma", "uniform", "sigmoid", "shift", "lognorm", "nextdit_shift"],
         default="shift",
-        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
-        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
+        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, lognorm, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
+        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid, lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
     )
     parser.add_argument(
         "--sigmoid_scale",
         type=float,
         default=1.0,
         help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
     )
+
+    parser.add_argument(
+        "--lognorm_alpha",
+        type=float,
+        default=0.75,
+        help='Alpha factor for distribute timestep to the center/early (only used when timestep-sampling is "lognorm"). / 中心/早期へのタイムステップ分配のアルファ係数(timestep-samplingが"lognorm"の場合のみ有効)。',
+    )
     parser.add_argument(
         "--model_prediction_type",
         choices=["raw", "additive", "sigma_scaled"],

From 8ad9172162e1c232d4b6640a6302bd04f86d9ced Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 20:40:49 +0700
Subject: [PATCH 4/9] Update lumina_train.py

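The disabled branch below followed the PyTorch "optimizer step in backward"
recipe, with one optimizer per block of parameters. A minimal sketch of that
pattern outside this repository (hypothetical one-layer model; the tutorial
URL kept in the commented code describes the same idea):

    import torch

    model = torch.nn.Linear(8, 8)
    optimizers = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}

    def step_in_backward(param: torch.Tensor) -> None:
        # step and free this parameter's gradient as soon as it is accumulated
        optimizers[param].step()
        optimizers[param].zero_grad()

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(step_in_backward)

    model(torch.randn(2, 8)).sum().backward()  # steps happen inside backward()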
---
 lumina_train.py | 112 ++++++++++++++++++++++++++----------------------
 1 file changed, 60 insertions(+), 52 deletions(-)

diff --git a/lumina_train.py b/lumina_train.py
index ca60c658..7680c3fa 100644
--- a/lumina_train.py
+++ b/lumina_train.py
@@ -361,65 +361,73 @@ def train(args):
 
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
 
-    if args.blockwise_fused_optimizers:
-        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
-        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
-        # This balances memory usage and management complexity.
+    # if args.blockwise_fused_optimizers:
+    #     # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+    #     # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
+    #     # This balances memory usage and management complexity.
 
-        # split params into groups. currently different learning rates are not supported
-        grouped_params = []
-        param_group = {}
-        for group in params_to_optimize:
-            named_parameters = list(nextdit.named_parameters())
-            assert len(named_parameters) == len(
-                group["params"]
-            ), "number of parameters does not match"
-            for p, np in zip(group["params"], named_parameters):
-                # determine target layer and block index for each parameter
-                block_type = "other"  # double, single or other
-                if np[0].startswith("double_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "double"
-                elif np[0].startswith("single_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "single"
-                else:
-                    block_index = -1
+    #     # split params into groups. currently different learning rates are not supported
+    #     grouped_params = []
+    #     param_group = {}
+    #     for group in params_to_optimize:
+    #         named_parameters = list(nextdit.named_parameters())
+    #         assert len(named_parameters) == len(
+    #             group["params"]
+    #         ), "number of parameters does not match"
+    #         for p, np in zip(group["params"], named_parameters):
+    #             # determine target layer and block index for each parameter
+    #             block_type = "other"  # double, single or other
+    #             if np[0].startswith("double_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "double"
+    #             elif np[0].startswith("single_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "single"
+    #             else:
+    #                 block_index = -1
 
-                param_group_key = (block_type, block_index)
-                if param_group_key not in param_group:
-                    param_group[param_group_key] = []
-                param_group[param_group_key].append(p)
+    #             param_group_key = (block_type, block_index)
+    #             if param_group_key not in param_group:
+    #                 param_group[param_group_key] = []
+    #             param_group[param_group_key].append(p)
 
-        block_types_and_indices = []
-        for param_group_key, param_group in param_group.items():
-            block_types_and_indices.append(param_group_key)
-            grouped_params.append({"params": param_group, "lr": args.learning_rate})
+    #     block_types_and_indices = []
+    #     for param_group_key, param_group in param_group.items():
+    #         block_types_and_indices.append(param_group_key)
+    #         grouped_params.append({"params": param_group, "lr": args.learning_rate})
 
-            num_params = 0
-            for p in param_group:
-                num_params += p.numel()
-            accelerator.print(f"block {param_group_key}: {num_params} parameters")
+    #         num_params = 0
+    #         for p in param_group:
+    #             num_params += p.numel()
+    #         accelerator.print(f"block {param_group_key}: {num_params} parameters")
 
-        # prepare optimizers for each group
-        optimizers = []
-        for group in grouped_params:
-            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
-            optimizers.append(optimizer)
-        optimizer = optimizers[0]  # avoid error in the following code
+    #     # prepare optimizers for each group
+    #     optimizers = []
+    #     for group in grouped_params:
+    #         _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+    #         optimizers.append(optimizer)
+    #     optimizer = optimizers[0]  # avoid error in the following code
 
-        logger.info(
-            f"using {len(optimizers)} optimizers for blockwise fused optimizers"
-        )
+    #     logger.info(
+    #         f"using {len(optimizers)} optimizers for blockwise fused optimizers"
+    #     )
 
-        if train_util.is_schedulefree_optimizer(optimizers[0], args):
-            raise ValueError(
-                "Schedule-free optimizer is not supported with blockwise fused optimizers"
-            )
-        optimizer_train_fn = lambda: None  # dummy function
-        optimizer_eval_fn = lambda: None  # dummy function
-    else:
-        _, _, optimizer = train_util.get_optimizer(
+    #     if train_util.is_schedulefree_optimizer(optimizers[0], args):
+    #         raise ValueError(
+    #             "Schedule-free optimizer is not supported with blockwise fused optimizers"
+    #         )
+    #     optimizer_train_fn = lambda: None  # dummy function
+    #     optimizer_eval_fn = lambda: None  # dummy function
+    # else:
+    #     _, _, optimizer = train_util.get_optimizer(
+    #         args, trainable_params=params_to_optimize
+    #     )
+    #     optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+    #         optimizer, args
+    #     )
+
+    # Currently, when using blockwise_fused_optimizers, the model weights are not updated.
+    _, _, optimizer = train_util.get_optimizer(
             args, trainable_params=params_to_optimize
         )
         optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(

From e222084e0353c21ee5d59254e2f68155e7687086 Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 20:41:24 +0700
Subject: [PATCH 5/9] Update lumina_train.py

---
 lumina_train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lumina_train.py b/lumina_train.py
index 7680c3fa..44b8cafb 100644
--- a/lumina_train.py
+++ b/lumina_train.py
@@ -430,9 +430,9 @@ def train(args):
     _, _, optimizer = train_util.get_optimizer(
         args, trainable_params=params_to_optimize
     )
-        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
-            optimizer, args
-        )
+    optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+        optimizer, args
+    )
 
     # prepare dataloader
     # strategies are set here because they cannot be referenced in another process. Copy them with the dataset

From fe7005caaad56eca830e64822b9f32e010ad98fc Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 20:49:24 +0700
Subject: [PATCH 6/9] Update prepare_accelerator to fix an error when training
 on multi-GPU

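For context, a minimal sketch of the added handler in isolation (a
hypothetical setup, not the full prepare_accelerator): passing
DistributedDataParallelKwargs(find_unused_parameters=True) makes DDP
tolerate parameters that receive no gradient in a step, at the cost of an
extra graph traversal per iteration.

    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
    # model = accelerator.prepare(model)  # DDP no longer raises on unused params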
---
 library/train_util.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 756d88b1..301ad771 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5525,6 +5525,9 @@ def prepare_accelerator(args: argparse.Namespace):
             if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
             else None
         ),
+        (
+            DistributedDataParallelKwargs(find_unused_parameters=True)
+        ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

From a5f38044051ec9d2b799cdd3999860735a06ffdb Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 20:53:01 +0700
Subject: [PATCH 7/9] Update lumina_train_util.py

---
 library/lumina_train_util.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py
index 56b5c0b5..360039b8 100644
--- a/library/lumina_train_util.py
+++ b/library/lumina_train_util.py
@@ -1057,10 +1057,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
 
     parser.add_argument(
         "--timestep_sampling",
-        choices=["sigma", "uniform", "sigmoid", "shift", "lognorm", "nextdit_shift"],
+        choices=["sigma", "uniform", "lognorm", "sigmoid", "shift", "nextdit_shift"],
         default="shift",
-        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, lognorm, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
-        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid, lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
+        help="Method to sample timesteps: sigma-based, uniform random, lognorm, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
+        " / タイムステップをサンプリングする方法:sigma、random uniform、lognorm、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
     )
     parser.add_argument(
         "--sigmoid_scale",

From 717502bd639297292bc3dd79996c760c17eaf666 Mon Sep 17 00:00:00 2001
From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com>
Date: Mon, 29 Sep 2025 21:12:23 +0700
Subject: [PATCH 8/9] Update lumina_train_util.py

- " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid, lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", + help="Method to sample timesteps: sigma-based, uniform random, lognorm, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." + " / タイムステップをサンプリングする方法:sigma、random uniform、lognorm、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", ) parser.add_argument( "--sigmoid_scale", From 717502bd639297292bc3dd79996c760c17eaf666 Mon Sep 17 00:00:00 2001 From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com> Date: Mon, 29 Sep 2025 21:12:23 +0700 Subject: [PATCH 8/9] Update lumina_train_util.py --- library/lumina_train_util.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 360039b8..32ac8b03 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -8,7 +8,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator import torch from torch import Tensor -from torch.distributions import LogNormal from accelerate import Accelerator, PartialState from transformers import Gemma2Model from tqdm import tqdm @@ -858,13 +857,6 @@ def get_noisy_model_input_and_timesteps( timesteps = 1 - t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * noise + t * latents - elif args.timestep_sampling == "lognorm": - lognormal = LogNormal(loc=0, scale=0.333) - t = lognormal.sample((int(timesteps * args.lognorm_alpha),)).to(device) - - t = ((1 - t/t.max()) * 1000) - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * noise + t * latents else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -1057,10 +1049,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "lognorm", "sigmoid", "shift", "nextdit_shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], default="shift", - help="Method to sample timesteps: sigma-based, uniform random, lognorm, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." - " / タイムステップをサンプリングする方法:sigma、random uniform、lognorm、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", ) parser.add_argument( "--sigmoid_scale", @@ -1069,12 +1061,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', ) - parser.add_argument( - "--lognorm_alpha", - type=float, - default=0.75, - help='Alpha factor for distribute timestep to the center/early (only used when timestep-sampling is "lognorm"). 
---
 lumina_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lumina_train.py b/lumina_train.py
index 44b8cafb..88bb88bb 100644
--- a/lumina_train.py
+++ b/lumina_train.py
@@ -751,7 +751,7 @@ def train(args):
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = nextdit(
                     x=noisy_model_input,  # image latents (B, C, H, W)
-                    t=timesteps / 1000,  # timesteps需要除以1000来匹配模型预期
+                    t=1 - timesteps / 1000,  # divide timesteps by 1000 to match what the model expects
                     cap_feats=gemma2_hidden_states,  # Gemma2的hidden states作为caption features
                     cap_mask=gemma2_attn_mask.to(
                         dtype=torch.int32