diff --git a/train_network.py b/train_network.py
index f66cdeb4..d6bc66ed 100644
--- a/train_network.py
+++ b/train_network.py
@@ -389,7 +389,18 @@ class NetworkTrainer:
                         latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
                     else:
                         # latentに変換
-                        latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
+                        if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
+                            latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
+                        else:
+                            chunks = [
+                                batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
+                            ]
+                            list_latents = []
+                            for chunk in chunks:
+                                with torch.no_grad():
+                                    chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
+                                list_latents.append(chunk)
+                            latents = torch.cat(list_latents, dim=0)
 
                         # NaNが含まれていれば警告を表示し0に置き換える
                         if torch.any(torch.isnan(latents)):