diff --git a/train_network.py b/train_network.py index 0f92e4f9..e2974c47 100644 --- a/train_network.py +++ b/train_network.py @@ -469,7 +469,7 @@ class NetworkTrainer: loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) wav_loss = None - if args.wavelet_loss_alpha: + if args.wavelet_loss: if args.wavelet_loss_rectified_flow: # Calculate flow-based clean estimate using the target flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target