speed up latents nan replace

This commit is contained in:
liubo0902
2023-12-20 09:35:17 +08:00
committed by GitHub
parent 0908c5414d
commit 8c7d05afd2

View File

@@ -767,7 +767,7 @@ class NetworkTrainer:
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]