Update train_network.py

This commit is contained in:
DKnight54
2025-03-28 19:46:23 +08:00
committed by GitHub
parent 7c94d386c1
commit c6d986e868

View File

@@ -913,7 +913,7 @@ class NetworkTrainer:
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
if args.train_vae_batch is None or len(batch["latents"]) <= args.train_vae_batch
if args.train_vae_batch is None or len(batch["latents"]) <= args.train_vae_batch:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)