diff --git a/train_network.py b/train_network.py index 55be9601..885203cf 100644 --- a/train_network.py +++ b/train_network.py @@ -1566,9 +1566,6 @@ class NetworkTrainer: } self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) - if accelerator.sync_gradients: - batch_size = 0 # reset batch size - restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep @@ -1576,6 +1573,9 @@ class NetworkTrainer: accelerator.unwrap_model(network).train() progress_bar.unpause() + if accelerator.sync_gradients: + batch_size = 0 # reset batch size + if global_step >= args.max_train_steps: break