Compare commits

...

2 Commits

Author SHA1 Message Date
rockerBOO
cda4c076fc fix sigma view 2025-10-12 16:26:32 -04:00
rockerBOO
d90b293792 fix batch expansion of sigmas when batch size > 1 2025-10-12 16:11:41 -04:00

View File

@@ -477,7 +477,8 @@ class NetworkTrainer:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.wavelet_loss:
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigma, noise_pred, noise):
sigmas = sigma.view(-1, 1, 1, 1).expand(noise_pred.size(0), -1, -1, -1)
if denoise_latents:
if self.is_flow_matching:
# denoise latents to use for wavelet loss