From 47359b8fac9602415f56b1f7e3f25a00255a1d78 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:17:40 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 82110066..d549378c 100644 --- a/train_network.py +++ b/train_network.py @@ -989,7 +989,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1015,7 +1015,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)