From f36b545480dbe900ca4d21ae6685d3cc1e567e88 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 6 Feb 2025 19:14:05 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 6ec1553a..7105a1fa 100644 --- a/train_network.py +++ b/train_network.py @@ -1028,8 +1028,9 @@ class NetworkTrainer: if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) + if args.sample_every_n_steps is not None and steps % args.sample_every_n_steps != 0: + example_tuple = (latents, batch["captions"]) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1085,7 +1086,9 @@ class NetworkTrainer: if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) + if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs != 0: + example_tuple = (latents, batch["captions"]) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # end of epoch