diff --git a/train_network.py b/train_network.py index e735c582..9b8036f8 100644 --- a/train_network.py +++ b/train_network.py @@ -1276,7 +1276,7 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch + 1) - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here # TRAINING skipped_dataloader = None @@ -1382,6 +1382,7 @@ class NetworkTrainer: ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" @@ -1432,6 +1433,7 @@ class NetworkTrainer: accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() if global_step >= args.max_train_steps: break @@ -1443,6 +1445,7 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), @@ -1500,6 +1503,7 @@ class NetworkTrainer: accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() # END OF EPOCH if is_tracking: