mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Move get_huber_threshold_if_needed
This commit is contained in:
@@ -5905,27 +5905,6 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
|
||||
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
|
||||
return None
|
||||
|
||||
b_size = timesteps.shape[0]
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
result = torch.exp(-alpha * timesteps) * args.huber_scale
|
||||
elif args.huber_schedule == "snr":
|
||||
if not hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
result = result.to(timesteps.device)
|
||||
elif args.huber_schedule == "constant":
|
||||
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor:
|
||||
@@ -6004,6 +5983,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.
|
||||
return noise, noisy_latents, timesteps
|
||||
|
||||
|
||||
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
|
||||
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
|
||||
return None
|
||||
|
||||
b_size = timesteps.shape[0]
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
result = torch.exp(-alpha * timesteps) * args.huber_scale
|
||||
elif args.huber_schedule == "snr":
|
||||
if not hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
result = result.to(timesteps.device)
|
||||
elif args.huber_schedule == "constant":
|
||||
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Add noise to the latents according to the noise magnitude at each timestep
|
||||
|
||||
Reference in New Issue
Block a user