From 8c7d05afd282912fda04fe60cfb644b543d5a634 Mon Sep 17 00:00:00 2001 From: liubo0902 <38622806+liubo0902@users.noreply.github.com> Date: Wed, 20 Dec 2023 09:35:17 +0800 Subject: [PATCH] speed up latents nan replace --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 1cbed2e7..2d53858b 100644 --- a/train_network.py +++ b/train_network.py @@ -767,7 +767,7 @@ class NetworkTrainer: # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor b_size = latents.shape[0]