From 76b761943b5166f496aa1cb8ffbcc2d04469346a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:53:57 +0900 Subject: [PATCH] fix: simplify validation step condition in NetworkTrainer --- train_network.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 8bfb1925..99c58f49 100644 --- a/train_network.py +++ b/train_network.py @@ -1414,12 +1414,9 @@ class NetworkTrainer: ) accelerator.log(logs, step=global_step) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... + should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval()