mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Merge f7fc7ddda2 into a5a162044c
This commit is contained in:
@@ -475,11 +475,7 @@ def sample_image_inference(
|
|||||||
|
|
||||||
|
|
||||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||||
# the following implementation was original for t=0: clean / t=1: noise
|
|
||||||
# Since we adopt the reverse, the 1-t operations are needed
|
|
||||||
t = 1 - t
|
|
||||||
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
t = 1 - t
|
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
||||||
@@ -802,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
|
|||||||
weighting = torch.ones_like(sigmas)
|
weighting = torch.ones_like(sigmas)
|
||||||
return weighting
|
return weighting
|
||||||
|
|
||||||
|
# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
|
||||||
def get_noisy_model_input_and_timesteps(
|
def get_noisy_model_input_and_timesteps(
|
||||||
args, noise_scheduler, latents, noise, device, dtype
|
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
|
||||||
Get noisy model input and timesteps.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args (argparse.Namespace): Arguments.
|
|
||||||
noise_scheduler (noise_scheduler): Noise scheduler.
|
|
||||||
latents (Tensor): Latents.
|
|
||||||
noise (Tensor): Latent noise.
|
|
||||||
device (torch.device): Device.
|
|
||||||
dtype (torch.dtype): Data type
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Tuple[Tensor, Tensor, Tensor]:
|
|
||||||
noisy model input
|
|
||||||
timesteps
|
|
||||||
sigmas
|
|
||||||
"""
|
|
||||||
bsz, _, h, w = latents.shape
|
bsz, _, h, w = latents.shape
|
||||||
sigmas = None
|
assert bsz > 0, "Batch size not large enough"
|
||||||
|
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||||
# Simple random t-based noise sampling
|
# Simple random sigma-based noise sampling
|
||||||
if args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "sigmoid":
|
||||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||||
else:
|
else:
|
||||||
t = torch.rand((bsz,), device=device)
|
sigmas = torch.rand((bsz,), device=device)
|
||||||
|
|
||||||
timesteps = t * 1000.0
|
timesteps = sigmas * num_timesteps
|
||||||
t = t.view(-1, 1, 1, 1)
|
|
||||||
noisy_model_input = (1 - t) * noise + t * latents
|
|
||||||
elif args.timestep_sampling == "shift":
|
elif args.timestep_sampling == "shift":
|
||||||
shift = args.discrete_flow_shift
|
shift = args.discrete_flow_shift
|
||||||
logits_norm = torch.randn(bsz, device=device)
|
sigmas = torch.randn(bsz, device=device)
|
||||||
logits_norm = (
|
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||||
logits_norm * args.sigmoid_scale
|
sigmas = sigmas.sigmoid()
|
||||||
) # larger scale for more uniform sampling
|
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||||
timesteps = logits_norm.sigmoid()
|
timesteps = sigmas * num_timesteps
|
||||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
|
||||||
|
|
||||||
t = timesteps.view(-1, 1, 1, 1)
|
|
||||||
timesteps = timesteps * 1000.0
|
|
||||||
noisy_model_input = (1 - t) * noise + t * latents
|
|
||||||
elif args.timestep_sampling == "nextdit_shift":
|
elif args.timestep_sampling == "nextdit_shift":
|
||||||
t = torch.rand((bsz,), device=device)
|
sigmas = torch.rand((bsz,), device=device)
|
||||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||||
t = time_shift(mu, 1.0, t)
|
sigmas = time_shift(mu, 1.0, sigmas)
|
||||||
|
|
||||||
timesteps = t * 1000.0
|
timesteps = sigmas * num_timesteps
|
||||||
t = t.view(-1, 1, 1, 1)
|
elif args.timestep_sampling == "flux_shift":
|
||||||
noisy_model_input = (1 - t) * noise + t * latents
|
sigmas = torch.randn(bsz, device=device)
|
||||||
|
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||||
|
sigmas = sigmas.sigmoid()
|
||||||
|
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||||
|
sigmas = time_shift(mu, 1.0, sigmas)
|
||||||
|
timesteps = sigmas * num_timesteps
|
||||||
else:
|
else:
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
# for weighting schemes where we sample timesteps non-uniformly
|
# for weighting schemes where we sample timesteps non-uniformly
|
||||||
@@ -867,14 +844,24 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
logit_std=args.logit_std,
|
logit_std=args.logit_std,
|
||||||
mode_scale=args.mode_scale,
|
mode_scale=args.mode_scale,
|
||||||
)
|
)
|
||||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
indices = (u * num_timesteps).long()
|
||||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||||
|
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||||
|
|
||||||
# Add noise according to flow matching.
|
# Broadcast sigmas to latent shape
|
||||||
sigmas = get_sigmas(
|
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||||
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
|
|
||||||
)
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
|
# (this is the forward diffusion process)
|
||||||
|
if args.ip_noise_gamma:
|
||||||
|
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||||
|
if args.ip_noise_gamma_random_strength:
|
||||||
|
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||||
|
else:
|
||||||
|
ip_noise_gamma = args.ip_noise_gamma
|
||||||
|
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||||
|
else:
|
||||||
|
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||||
|
|
||||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||||
|
|
||||||
@@ -1049,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--timestep_sampling",
|
"--timestep_sampling",
|
||||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
|
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
|
||||||
default="shift",
|
default="shift",
|
||||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
|
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
|
||||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
|
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sigmoid_scale",
|
"--sigmoid_scale",
|
||||||
|
|||||||
@@ -743,7 +743,7 @@ def train(args):
|
|||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = nextdit(
|
model_pred = nextdit(
|
||||||
x=noisy_model_input, # image latents (B, C, H, W)
|
x=noisy_model_input, # image latents (B, C, H, W)
|
||||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||||
cap_mask=gemma2_attn_mask.to(
|
cap_mask=gemma2_attn_mask.to(
|
||||||
dtype=torch.int32
|
dtype=torch.int32
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||||
model_pred = dit(
|
model_pred = dit(
|
||||||
x=img, # image latents (B, C, H, W)
|
x=img, # image latents (B, C, H, W)
|
||||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user