mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Change val loss calculate method
This commit is contained in:
@@ -1383,16 +1383,20 @@ class NetworkTrainer:
|
||||
else:
|
||||
target = noise
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
# huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
# loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
# if weighting is not None:
|
||||
# loss = loss * weighting
|
||||
# if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
# loss = apply_masked_loss(loss, batch)
|
||||
# loss = loss.mean([1, 2, 3])
|
||||
|
||||
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
# loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
||||
Reference in New Issue
Block a user