From 1f9ba40b8b70fd08e6b87a70727d5e789666a925 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:32:07 -0500 Subject: [PATCH] Add step break for validation epoch. Remove unused variable --- train_network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f870734f..ce34f26d 100644 --- a/train_network.py +++ b/train_network.py @@ -1439,6 +1439,9 @@ class NetworkTrainer: ) for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() @@ -1447,7 +1450,6 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs)