add network.train()/eval() for validation

This commit is contained in:
Kohya S
2025-01-27 21:35:43 +09:00
parent b6a3093216
commit 29f31d005f

View File

@@ -1276,7 +1276,7 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
# TRAINING
skipped_dataloader = None
@@ -1382,6 +1382,7 @@ class NetworkTrainer:
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
val_progress_bar = tqdm(
range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
@@ -1432,6 +1433,7 @@ class NetworkTrainer:
accelerator.log(logs, step=global_step)
optimizer_train_fn()
accelerator.unwrap_model(network).train()
if global_step >= args.max_train_steps:
break
@@ -1443,6 +1445,7 @@ class NetworkTrainer:
if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
val_progress_bar = tqdm(
range(validation_steps),
@@ -1500,6 +1503,7 @@ class NetworkTrainer:
accelerator.log(logs, step=global_step)
optimizer_train_fn()
accelerator.unwrap_model(network).train()
# END OF EPOCH
if is_tracking: