From 0fd2c8af610c58951da8b16eeb3f4344e3a1d3a4 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 30 Mar 2025 04:34:42 +0800 Subject: [PATCH] Update train_util.py --- library/train_util.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bbfd22af..1054986d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5710,19 +5710,17 @@ def sample_image_inference( controlnet=controlnet, controlnet_image=controlnet_image, ) - - if torch.cuda.is_available(): - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() - + logger.info(f"latents: {latents.shape}") + clean_memory_on_device(accelerator.device) + image = pipeline.latents_to_image(latents)[0] if "original_lantent" in prompt_dict: #Prevent out of VRAM error - if torch.cuda.is_available(): - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + + clean_memory_on_device(accelerator.device) original_latent = prompt_dict.get("original_lantent") + logger.info(f"original_latent: {original_latent.shape}") original_image = pipeline.latents_to_image(original_latent)[0] text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2) new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height))