fix: simplify validation step condition in NetworkTrainer

This commit is contained in:
Kohya S
2025-02-11 21:53:57 +09:00
parent cd80752175
commit 76b761943b

View File

@@ -1414,12 +1414,9 @@ class NetworkTrainer:
)
accelerator.log(logs, step=global_step)
# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()