mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add perturbation on noisy_model_input if needed
This commit is contained in:
@@ -410,20 +410,11 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents)
|
||||
else:
|
||||
noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents)
|
||||
else:
|
||||
noise = input_noise
|
||||
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
@@ -474,6 +465,15 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents)
|
||||
else:
|
||||
xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents)
|
||||
noisy_model_input += xi
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user