mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
change order to match upstream
This commit is contained in:
@@ -413,8 +413,6 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
bsz, _, h, w = latents.shape
|
bsz, _, h, w = latents.shape
|
||||||
sigmas = None
|
|
||||||
|
|
||||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||||
# Simple random t-based noise sampling
|
# Simple random t-based noise sampling
|
||||||
if args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "sigmoid":
|
||||||
@@ -463,9 +461,9 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
|
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
|
||||||
else:
|
else:
|
||||||
ip_noise_gamma = args.ip_noise_gamma
|
ip_noise_gamma = args.ip_noise_gamma
|
||||||
noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents
|
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||||
else:
|
else:
|
||||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||||
|
|
||||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user