Refactor sigmas and timesteps

This commit is contained in:
rockerBOO
2025-03-20 14:32:56 -04:00
parent f974c6b257
commit 16cef81aea

View File

@@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma return sigma
@@ -413,32 +411,30 @@ def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape bsz, _, h, w = latents.shape
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling # Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid": if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main # https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else: else:
t = torch.rand((bsz,), device=device) sigmas = torch.rand((bsz,), device=device)
sigmas = t.view(-1, 1, 1, 1) timesteps = sigmas * num_timesteps
timesteps = t * 1000.0
elif args.timestep_sampling == "shift": elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device) sigmas = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid() sigmas = sigmas.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
sigmas = timesteps.view(-1, 1, 1, 1) timesteps = sigmas * num_timesteps
timesteps = timesteps * 1000.0
elif args.timestep_sampling == "flux_shift": elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device) sigmas = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid() sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
timesteps = time_shift(mu, 1.0, timesteps) sigmas = time_shift(mu, 1.0, sigmas)
sigmas = timesteps.view(-1, 1, 1, 1) timesteps = sigmas * num_timesteps
timesteps = timesteps * 1000.0
else: else:
# Sample a random timestep for each image # Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly # for weighting schemes where we sample timesteps non-uniformly
@@ -449,10 +445,13 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std, logit_std=args.logit_std,
mode_scale=args.mode_scale, mode_scale=args.mode_scale,
) )
indices = (u * noise_scheduler.config.num_train_timesteps).long() indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device) timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
if args.ip_noise_gamma: if args.ip_noise_gamma: