diff --git a/train_network.py b/train_network.py index 6f1652fd..e735c582 100644 --- a/train_network.py +++ b/train_network.py @@ -1381,6 +1381,8 @@ class NetworkTrainer: and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) @@ -1429,6 +1431,8 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + optimizer_train_fn() + if global_step >= args.max_train_steps: break @@ -1438,6 +1442,8 @@ class NetworkTrainer: ) if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, @@ -1493,6 +1499,8 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + optimizer_train_fn() + # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}