Fix IP noise calculation

This commit is contained in:
rockerBOO
2025-03-19 13:50:04 -04:00
parent 1eddac26b0
commit 5d5a7d2acf

View File

@@ -423,29 +423,24 @@ def get_noisy_model_input_and_timesteps(
else:
t = torch.rand((bsz,), device=device)
sigmas = t.view(-1, 1, 1, 1)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
sigmas = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
sigmas = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -458,10 +453,7 @@ def get_noisy_model_input_and_timesteps(
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -471,7 +463,9 @@ def get_noisy_model_input_and_timesteps(
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input += ip_noise_gamma * xi
noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents
else:
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas