From dbd835ee4b3198f781d77b77cd7918de669c441d Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 8 Apr 2025 21:57:16 +0900 Subject: [PATCH] train: Optimize VAE encoding by handling batch sizes for images --- sdxl_train.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index a60f6df6..84570144 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -640,14 +640,23 @@ def train(args): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: - with torch.no_grad(): - # latentに変換 - latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: + with torch.no_grad(): + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + else: + chunks = [ + batch["images"][i : i + args.vae_batch_size] + for i in range(0, len(batch["images"]), args.vae_batch_size) + ] + list_latents = [] + for chunk in chunks: + with torch.no_grad(): + # latentに変換 + list_latents.append( + vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + ) + latents = torch.cat(list_latents, dim=0) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)