diff --git a/train_network.py b/train_network.py index 7bf125dc..6953bb17 100644 --- a/train_network.py +++ b/train_network.py @@ -912,14 +912,22 @@ class NetworkTrainer: if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: - with torch.no_grad(): - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) - + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: + with torch.no_grad(): + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + else: + chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)] + list_latents = [] + for chunk in chunks: + with torch.no_grad(): + # latentに変換 + list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)) + latents = torch.cat(list_latents, dim=0) # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor # get multiplier for each sample