set dtype before calling ae closes #1562

This commit is contained in:
Kohya S
2024-09-05 12:20:07 +09:00
parent 90ed2dfb52
commit d9129522a6

View File

@@ -651,7 +651,7 @@ def train(args):
else: else:
with torch.no_grad(): with torch.no_grad():
# encode images to latents. images are [-1, 1] # encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"]) latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える # NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)): if torch.any(torch.isnan(latents)):