mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor sigmas and timesteps
This commit is contained in:
@@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
|
|||||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
while len(sigma.shape) < n_dim:
|
|
||||||
sigma = sigma.unsqueeze(-1)
|
|
||||||
return sigma
|
return sigma
|
||||||
|
|
||||||
|
|
||||||
@@ -413,32 +411,30 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
bsz, _, h, w = latents.shape
|
bsz, _, h, w = latents.shape
|
||||||
|
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||||
# Simple random t-based noise sampling
|
# Simple random sigma-based noise sampling
|
||||||
if args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "sigmoid":
|
||||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||||
else:
|
else:
|
||||||
t = torch.rand((bsz,), device=device)
|
sigmas = torch.rand((bsz,), device=device)
|
||||||
|
|
||||||
sigmas = t.view(-1, 1, 1, 1)
|
timesteps = sigmas * num_timesteps
|
||||||
timesteps = t * 1000.0
|
|
||||||
elif args.timestep_sampling == "shift":
|
elif args.timestep_sampling == "shift":
|
||||||
shift = args.discrete_flow_shift
|
shift = args.discrete_flow_shift
|
||||||
logits_norm = torch.randn(bsz, device=device)
|
sigmas = torch.randn(bsz, device=device)
|
||||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||||
timesteps = logits_norm.sigmoid()
|
sigmas = sigmas.sigmoid()
|
||||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||||
sigmas = timesteps.view(-1, 1, 1, 1)
|
timesteps = sigmas * num_timesteps
|
||||||
timesteps = timesteps * 1000.0
|
|
||||||
elif args.timestep_sampling == "flux_shift":
|
elif args.timestep_sampling == "flux_shift":
|
||||||
logits_norm = torch.randn(bsz, device=device)
|
sigmas = torch.randn(bsz, device=device)
|
||||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||||
timesteps = logits_norm.sigmoid()
|
sigmas = sigmas.sigmoid()
|
||||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||||
timesteps = time_shift(mu, 1.0, timesteps)
|
sigmas = time_shift(mu, 1.0, sigmas)
|
||||||
sigmas = timesteps.view(-1, 1, 1, 1)
|
timesteps = sigmas * num_timesteps
|
||||||
timesteps = timesteps * 1000.0
|
|
||||||
else:
|
else:
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
# for weighting schemes where we sample timesteps non-uniformly
|
# for weighting schemes where we sample timesteps non-uniformly
|
||||||
@@ -449,10 +445,13 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
logit_std=args.logit_std,
|
logit_std=args.logit_std,
|
||||||
mode_scale=args.mode_scale,
|
mode_scale=args.mode_scale,
|
||||||
)
|
)
|
||||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
indices = (u * num_timesteps).long()
|
||||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||||
|
|
||||||
|
# Broadcast sigmas to latent shape
|
||||||
|
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||||
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
if args.ip_noise_gamma:
|
if args.ip_noise_gamma:
|
||||||
|
|||||||
Reference in New Issue
Block a user