diff --git a/train_network.py b/train_network.py index 6953bb17..ea3d5841 100644 --- a/train_network.py +++ b/train_network.py @@ -1036,6 +1036,8 @@ class NetworkTrainer: self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + clean_memory_on_device(accelerator.device) + # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() @@ -1092,6 +1094,8 @@ class NetworkTrainer: self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + clean_memory_on_device(accelerator.device) + # end of epoch # metadata["ss_epoch"] = str(num_train_epochs)