diff --git a/train_network.py b/train_network.py index 8bfb1925..99c58f49 100644 --- a/train_network.py +++ b/train_network.py @@ -1414,12 +1414,9 @@ class NetworkTrainer: ) accelerator.log(logs, step=global_step) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... + should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval()