diff --git a/train_network.py b/train_network.py index f870734f..ce34f26d 100644 --- a/train_network.py +++ b/train_network.py @@ -1439,6 +1439,9 @@ class NetworkTrainer: ) for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() @@ -1447,7 +1450,6 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs)