diff --git a/library/train_util.py b/library/train_util.py index 1377997c..32198774 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4820,6 +4820,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet=controlnet, controlnet_image=controlnet_image, ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + image = pipeline.latents_to_image(latents)[0] # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list