diff --git a/sdxl_train.py b/sdxl_train.py index aa2eb5df..8983673d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -487,7 +487,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index cb97859f..18c6bd05 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -394,7 +394,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 87f30301..6ae5377b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -363,7 +363,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: