Use args.wavelet_loss to activate

This commit is contained in:
rockerBOO
2025-04-12 04:10:48 -04:00
parent 19ce0ae61f
commit 40128b7dc0

View File

@@ -469,7 +469,7 @@ class NetworkTrainer:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
wav_loss = None
if args.wavelet_loss_alpha:
if args.wavelet_loss:
if args.wavelet_loss_rectified_flow:
# Calculate flow-based clean estimate using the target
flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target