Merge pull request #1009 from liubo0902/main

speed up latents nan replace
This commit is contained in:
Kohya S
2023-12-21 21:37:16 +09:00
committed by GitHub

View File

@@ -750,7 +750,7 @@ class NetworkTrainer:
# NaNが含まれていれば警告を表示し0に置き換える # NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)): if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros") accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor latents = latents * self.vae_scale_factor
b_size = latents.shape[0] b_size = latents.shape[0]