diff --git a/train_network.py b/train_network.py index 82110066..d549378c 100644 --- a/train_network.py +++ b/train_network.py @@ -989,7 +989,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1015,7 +1015,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)