mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add network.train()/eval() for validation
This commit is contained in:
@@ -1276,7 +1276,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch + 1)
|
metadata["ss_epoch"] = str(epoch + 1)
|
||||||
|
|
||||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
|
||||||
|
|
||||||
# TRAINING
|
# TRAINING
|
||||||
skipped_dataloader = None
|
skipped_dataloader = None
|
||||||
@@ -1382,6 +1382,7 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
|
accelerator.unwrap_model(network).eval()
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
|
range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
|
||||||
@@ -1432,6 +1433,7 @@ class NetworkTrainer:
|
|||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
|
accelerator.unwrap_model(network).train()
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@@ -1443,6 +1445,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if should_validate_epoch and len(val_dataloader) > 0:
|
if should_validate_epoch and len(val_dataloader) > 0:
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
|
accelerator.unwrap_model(network).eval()
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_steps),
|
range(validation_steps),
|
||||||
@@ -1500,6 +1503,7 @@ class NetworkTrainer:
|
|||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
|
accelerator.unwrap_model(network).train()
|
||||||
|
|
||||||
# END OF EPOCH
|
# END OF EPOCH
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
|
|||||||
Reference in New Issue
Block a user