fix potential issue with dtype

This commit is contained in:
Kohya S
2023-04-03 21:46:12 +09:00
parent 6f6f9b537f
commit 53cc3583df

View File

@@ -275,7 +275,7 @@ def train(args):
 with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
     with torch.no_grad():
         if "latents" in batch and batch["latents"] is not None:
-            latents = batch["latents"].to(accelerator.device)
+            latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
         else:
             # latentに変換
             latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()