call optimizer eval/train fn before/after validation

This commit is contained in:
Kohya S
2025-01-27 21:22:11 +09:00
parent 86a2f3fd26
commit b6a3093216

View File

@@ -1381,6 +1381,8 @@ class NetworkTrainer:
and global_step % args.validate_every_n_steps == 0 and global_step % args.validate_every_n_steps == 0
) )
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
) )
@@ -1429,6 +1431,8 @@ class NetworkTrainer:
} }
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
optimizer_train_fn()
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
@@ -1438,6 +1442,8 @@ class NetworkTrainer:
) )
if should_validate_epoch and len(val_dataloader) > 0: if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), range(validation_steps),
smoothing=0, smoothing=0,
@@ -1493,6 +1499,8 @@ class NetworkTrainer:
} }
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
optimizer_train_fn()
# END OF EPOCH # END OF EPOCH
if is_tracking: if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}