mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Change val loss calculate method
This commit is contained in:
@@ -1383,16 +1383,20 @@ class NetworkTrainer:
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
# huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
# loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||||
if weighting is not None:
|
# if weighting is not None:
|
||||||
loss = loss * weighting
|
# loss = loss * weighting
|
||||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
# if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
loss = apply_masked_loss(loss, batch)
|
# loss = apply_masked_loss(loss, batch)
|
||||||
loss = loss.mean([1, 2, 3])
|
# loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
||||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
# loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user