fix: simplify validation step condition in NetworkTrainer

This commit is contained in:
Kohya S
2025-02-11 21:53:57 +09:00
parent cd80752175
commit 76b761943b

View File

@@ -1414,12 +1414,9 @@ class NetworkTrainer:
) )
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
# VALIDATION PER STEP # VALIDATION PER STEP: global_step is already incremented
should_validate_step = ( # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
args.validate_every_n_steps is not None should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn() optimizer_eval_fn()
accelerator.unwrap_model(network).eval() accelerator.unwrap_model(network).eval()