fix sigma view

This commit is contained in:
rockerBOO
2025-10-12 16:26:32 -04:00
parent d90b293792
commit cda4c076fc

View File

@@ -478,7 +478,7 @@ class NetworkTrainer:
if args.wavelet_loss:
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigma, noise_pred, noise):
sigmas = sigma.expand(noise_pred.size(0), -1, -1, -1)
sigmas = sigma.view(-1, 1, 1, 1).expand(noise_pred.size(0), -1, -1, -1)
if denoise_latents:
if self.is_flow_matching:
# denoise latents to use for wavelet loss