Fix IP noise gamma to use random values

This commit is contained in:
rockerBOO
2025-03-18 18:42:35 -04:00
parent c8be141ae0
commit b425466e7b

View File

@@ -415,15 +415,15 @@ def get_noisy_model_input_and_timesteps(
bsz, _, h, w = latents.shape
sigmas = None
ip_noise_gamma = 0.0
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device) * args.ip_noise_gamma
ip_noise = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents)
else:
ip_noise_gamma = args.ip_noise_gamma
ip_noise = args.ip_noise_gamma * torch.randn_like(latents)
else:
ip_noise = torch.zeros_like(latents)
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
@@ -435,7 +435,7 @@ def get_noisy_model_input_and_timesteps(
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma
noisy_model_input = (1 - t) * latents + t * (noise + ip_noise)
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
@@ -445,7 +445,7 @@ def get_noisy_model_input_and_timesteps(
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma
noisy_model_input = (1 - t) * latents + t * (noise + ip_noise)
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
@@ -455,7 +455,7 @@ def get_noisy_model_input_and_timesteps(
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma
noisy_model_input = (1 - t) * latents + t * (noise + ip_noise)
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -471,7 +471,7 @@ def get_noisy_model_input_and_timesteps(
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + ip_noise_gamma + (1.0 - sigmas) * latents
noisy_model_input = sigmas * (noise + ip_noise) + (1.0 - sigmas) * latents
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas