Add no_half_vae for SDXL training, add nan check

This commit is contained in:
Kohya S
2023-06-26 20:38:09 +09:00
parent 56ca5dfa15
commit 2c461e4ad3
3 changed files with 35 additions and 11 deletions

View File

@@ -906,6 +906,11 @@ class BaseDataset(torch.utils.data.Dataset):
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
# check NaN
for info, latents1 in zip(batch, latents):
if torch.isnan(latents1).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
for info, latent in zip(batch, latents):
if cache_to_disk:
np.savez(