set dtype before calling ae closes #1562

This commit is contained in:
Kohya S
2024-09-05 12:20:07 +09:00
parent 90ed2dfb52
commit d9129522a6

View File

@@ -651,7 +651,7 @@ def train(args):
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"])
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):