diff --git a/train_network.py b/train_network.py index 8746ae26..7a589308 100644 --- a/train_network.py +++ b/train_network.py @@ -1035,9 +1035,10 @@ class NetworkTrainer: if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 + example_tuple = (latents.detach().clone(), batch["captions"]) if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0: accelerator.wait_for_everyone() - example_tuple = (latents, batch["captions"]) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # 指定ステップごとにモデルを保存 @@ -1095,7 +1096,6 @@ class NetworkTrainer: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs == 0: - example_tuple = (latents, batch["captions"]) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # end of epoch