diff --git a/flux_train.py b/flux_train.py index 32a36f03..0293b7be 100644 --- a/flux_train.py +++ b/flux_train.py @@ -651,7 +651,7 @@ def train(args): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"]) + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)):