Change val loss calculate method

This commit is contained in:
Hina Chen
2024-12-27 17:28:05 +08:00
parent 05bb9183fa
commit 62164e5792

View File

@@ -1383,16 +1383,20 @@ class NetworkTrainer:
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
# huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
# loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
# if weighting is not None:
# loss = loss * weighting
# if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
# loss = apply_masked_loss(loss, batch)
# loss = loss.mean([1, 2, 3])
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
# loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし