From eb2d9abff666945c03f4745185b4cdf85ef3ce11 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 13 Apr 2025 01:36:43 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index dd67ae58..dd7748ba 100644 --- a/train_network.py +++ b/train_network.py @@ -1031,12 +1031,12 @@ class NetworkTrainer: keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes - example_tuple = (latents, batch["captions"]) if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0: + accelerator.wait_for_everyone() example_tuple = (latents, batch["captions"]) self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) @@ -1095,6 +1095,7 @@ class NetworkTrainer: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs == 0: + accelerator.wait_for_everyone() example_tuple = (latents, batch["captions"]) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)