From d9129522a6effea7077f18cdea0ee733a5ac7cb0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 12:20:07 +0900 Subject: [PATCH] set dtype before calling ae closes #1562 --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 32a36f03..0293b7be 100644 --- a/flux_train.py +++ b/flux_train.py @@ -651,7 +651,7 @@ def train(args): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"]) + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)):