From e793d7780d779855f23210d1c88368fd9286666e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 4 Feb 2024 17:31:01 +0900 Subject: [PATCH] reduce peak VRAM in sample gen --- library/train_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 1377997c..32198774 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4820,6 +4820,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet=controlnet, controlnet_image=controlnet_image, ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + image = pipeline.latents_to_image(latents)[0] # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list