make timestep sampling behave in the standard way when huber loss is used

This commit is contained in:
recris
2024-09-21 12:58:32 +01:00
parent 0b7927e50b
commit e1f23af1bc

View File

@@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common(
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu')
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
if args.loss_type == "huber" or args.loss_type == "smooth_l1": if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timestep = timesteps.item()
if args.huber_schedule == "exponential": if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep) huber_c = torch.exp(-alpha * timesteps)
elif args.huber_schedule == "snr": elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant": elif args.huber_schedule == "constant":
huber_c = args.huber_c huber_c = torch.full((b_size,), args.huber_c)
else: else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
huber_c = huber_c.to(device)
timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == "l2": elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) huber_c = None # may be anything, as it's not used
huber_c = 1 # may be anything, as it's not used
else: else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}") raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long()
timesteps = timesteps.long().to(device)
return timesteps, huber_c return timesteps, huber_c
@@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
return noise, noisy_latents, timesteps, huber_c return noise, noisy_latents, timesteps, huber_c
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss( def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
): ):
if loss_type == "l2": if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber": elif loss_type == "huber":
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean": if reduction == "mean":
loss = torch.mean(loss) loss = torch.mean(loss)
elif reduction == "sum": elif reduction == "sum":
loss = torch.sum(loss) loss = torch.sum(loss)
elif loss_type == "smooth_l1": elif loss_type == "smooth_l1":
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean": if reduction == "mean":
loss = torch.mean(loss) loss = torch.mean(loss)