mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Merge pull request #2 from duongve13112002/fix_lumina_image_v2_reversed_timesteps
Fix lumina image v2 reversed timesteps
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributions import LogNormal
|
||||
from accelerate import Accelerator, PartialState
|
||||
from transformers import Gemma2Model
|
||||
from tqdm import tqdm
|
||||
@@ -808,6 +809,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Get noisy model input and timesteps.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): Arguments.
|
||||
noise_scheduler (noise_scheduler): Noise scheduler.
|
||||
@@ -815,58 +817,54 @@ def get_noisy_model_input_and_timesteps(
|
||||
noise (Tensor): Latent noise.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): Data type
|
||||
|
||||
Return:
|
||||
Tuple[Tensor, Tensor, Tensor]:
|
||||
noisy model input
|
||||
timesteps (reversed for Lumina: t=0 noise, t=1 image)
|
||||
timesteps
|
||||
sigmas
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=device)
|
||||
|
||||
# Reverse for Lumina: t=0 is noise, t=1 is image
|
||||
t_lumina = 1.0 - t
|
||||
timesteps = t_lumina * 1000.0
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale
|
||||
t = logits_norm.sigmoid()
|
||||
t = (t * shift) / (1 + (shift - 1) * t)
|
||||
|
||||
# Reverse for Lumina: t=0 is noise, t=1 is image
|
||||
t_lumina = 1.0 - t
|
||||
timesteps = t_lumina * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
logits_norm = (
|
||||
logits_norm * args.sigmoid_scale
|
||||
) # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
|
||||
elif args.timestep_sampling == "nextdit_shift":
|
||||
t = torch.rand((bsz,), device=device)
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
t = time_shift(mu, 1.0, t)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
timesteps = 1 - t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
|
||||
elif args.timestep_sampling == "lognorm":
|
||||
u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
|
||||
t = torch.sigmoid(u) # maps to [0,1]
|
||||
lognormal = LogNormal(loc=0, scale=0.333)
|
||||
t = lognormal.sample((int(timesteps * args.lognorm_alpha),)).to(device)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = ((1 - t/t.max()) * 1000)
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * noise + t * latents
|
||||
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -878,19 +876,14 @@ def get_noisy_model_input_and_timesteps(
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)
|
||||
|
||||
# Reverse for Lumina convention
|
||||
timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal
|
||||
|
||||
# Calculate sigmas with normal timesteps, then reverse interpolation
|
||||
sigmas_normal = get_sigmas(
|
||||
noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(
|
||||
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
|
||||
)
|
||||
# Reverse sigma interpolation for Lumina
|
||||
sigmas = 1.0 - sigmas_normal
|
||||
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
|
||||
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
@@ -1064,10 +1057,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "lognorm", "nextdit_shift"],
|
||||
default="shift",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, lognorm, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid, lognorm、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
@@ -1075,6 +1068,13 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=1.0,
|
||||
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lognorm_alpha",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help='Alpha factor for distribute timestep to the center/early (only used when timestep-sampling is "lognorm"). / 中心/早期へのタイムステップ分配のアルファ係数(timestep-samplingが"lognorm"の場合のみ有効)。',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_prediction_type",
|
||||
choices=["raw", "additive", "sigma_scaled"],
|
||||
|
||||
118
lumina_train.py
118
lumina_train.py
@@ -361,70 +361,78 @@ def train(args):
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
||||
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
|
||||
# This balances memory usage and management complexity.
|
||||
# if args.blockwise_fused_optimizers:
|
||||
# # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
||||
# # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
|
||||
# # This balances memory usage and management complexity.
|
||||
|
||||
# split params into groups. currently different learning rates are not supported
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
for group in params_to_optimize:
|
||||
named_parameters = list(nextdit.named_parameters())
|
||||
assert len(named_parameters) == len(
|
||||
group["params"]
|
||||
), "number of parameters does not match"
|
||||
for p, np in zip(group["params"], named_parameters):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
|
||||
if np[0].startswith("double_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
else:
|
||||
block_index = -1
|
||||
# # split params into groups. currently different learning rates are not supported
|
||||
# grouped_params = []
|
||||
# param_group = {}
|
||||
# for group in params_to_optimize:
|
||||
# named_parameters = list(nextdit.named_parameters())
|
||||
# assert len(named_parameters) == len(
|
||||
# group["params"]
|
||||
# ), "number of parameters does not match"
|
||||
# for p, np in zip(group["params"], named_parameters):
|
||||
# # determine target layer and block index for each parameter
|
||||
# block_type = "other" # double, single or other
|
||||
# if np[0].startswith("double_blocks"):
|
||||
# block_index = int(np[0].split(".")[1])
|
||||
# block_type = "double"
|
||||
# elif np[0].startswith("single_blocks"):
|
||||
# block_index = int(np[0].split(".")[1])
|
||||
# block_type = "single"
|
||||
# else:
|
||||
# block_index = -1
|
||||
|
||||
param_group_key = (block_type, block_index)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
# param_group_key = (block_type, block_index)
|
||||
# if param_group_key not in param_group:
|
||||
# param_group[param_group_key] = []
|
||||
# param_group[param_group_key].append(p)
|
||||
|
||||
block_types_and_indices = []
|
||||
for param_group_key, param_group in param_group.items():
|
||||
block_types_and_indices.append(param_group_key)
|
||||
grouped_params.append({"params": param_group, "lr": args.learning_rate})
|
||||
# block_types_and_indices = []
|
||||
# for param_group_key, param_group in param_group.items():
|
||||
# block_types_and_indices.append(param_group_key)
|
||||
# grouped_params.append({"params": param_group, "lr": args.learning_rate})
|
||||
|
||||
num_params = 0
|
||||
for p in param_group:
|
||||
num_params += p.numel()
|
||||
accelerator.print(f"block {param_group_key}: {num_params} parameters")
|
||||
# num_params = 0
|
||||
# for p in param_group:
|
||||
# num_params += p.numel()
|
||||
# accelerator.print(f"block {param_group_key}: {num_params} parameters")
|
||||
|
||||
# prepare optimizers for each group
|
||||
optimizers = []
|
||||
for group in grouped_params:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
||||
optimizers.append(optimizer)
|
||||
optimizer = optimizers[0] # avoid error in the following code
|
||||
# # prepare optimizers for each group
|
||||
# optimizers = []
|
||||
# for group in grouped_params:
|
||||
# _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
||||
# optimizers.append(optimizer)
|
||||
# optimizer = optimizers[0] # avoid error in the following code
|
||||
|
||||
logger.info(
|
||||
f"using {len(optimizers)} optimizers for blockwise fused optimizers"
|
||||
)
|
||||
# logger.info(
|
||||
# f"using {len(optimizers)} optimizers for blockwise fused optimizers"
|
||||
# )
|
||||
|
||||
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||
raise ValueError(
|
||||
"Schedule-free optimizer is not supported with blockwise fused optimizers"
|
||||
)
|
||||
optimizer_train_fn = lambda: None # dummy function
|
||||
optimizer_eval_fn = lambda: None # dummy function
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(
|
||||
# if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||
# raise ValueError(
|
||||
# "Schedule-free optimizer is not supported with blockwise fused optimizers"
|
||||
# )
|
||||
# optimizer_train_fn = lambda: None # dummy function
|
||||
# optimizer_eval_fn = lambda: None # dummy function
|
||||
# else:
|
||||
# _, _, optimizer = train_util.get_optimizer(
|
||||
# args, trainable_params=params_to_optimize
|
||||
# )
|
||||
# optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
|
||||
# optimizer, args
|
||||
# )
|
||||
|
||||
#Currently when using blockwise_fused_optimizers the weight of model is not updated.
|
||||
_, _, optimizer = train_util.get_optimizer(
|
||||
args, trainable_params=params_to_optimize
|
||||
)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
|
||||
optimizer, args
|
||||
)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
|
||||
optimizer, args
|
||||
)
|
||||
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
|
||||
Reference in New Issue
Block a user