Add a step break for the validation epoch; remove an unused variable

This commit is contained in:
rockerBOO
2025-01-03 15:32:07 -05:00
parent 695f38962c
commit 1f9ba40b8b

View File

@@ -1439,6 +1439,9 @@ class NetworkTrainer:
)
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
current_loss = loss.detach().item()
@@ -1447,7 +1450,6 @@ class NetworkTrainer:
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
if is_tracking:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
accelerator.log(logs)