From d90b29379204917b01ecebcf02dc401b098ef611 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 12 Oct 2025 16:07:55 -0400
Subject: [PATCH] fix batch expansion of sigmas when batch size > 1

---
 train_network.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/train_network.py b/train_network.py
index 3cccc19c..3811202e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -477,7 +477,8 @@ class NetworkTrainer:
                 loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)

                 if args.wavelet_loss:
-                    def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
+                    def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigma, noise_pred, noise):
+                        sigmas = sigma.expand(noise_pred.size(0), -1, -1, -1)
                         if denoise_latents:
                             if self.is_flow_matching:
                                 # denoise latents to use for wavelet loss