make the device of snr_weight the same as loss

This commit is contained in:
ddPn08
2023-06-01 10:32:34 +09:00
parent c8d209d36c
commit 1f1cae6c5a

View File

@@ -14,7 +14,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
loss = loss * snr_weight
return loss