From 53cc3583df729dc69349b687cb52caed786ff3b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 21:46:12 +0900 Subject: [PATCH] fix potential issue with dtype --- fine_tune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fine_tune.py b/fine_tune.py index 637a729a..50549878 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -275,7 +275,7 @@ def train(args): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()