Update train_util.py

This commit is contained in:
DKnight54
2025-03-30 04:34:42 +08:00
committed by GitHub
parent 1a957f9243
commit 0fd2c8af61

View File

@@ -5710,19 +5710,17 @@ def sample_image_inference(
controlnet=controlnet,
controlnet_image=controlnet_image,
)
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
logger.info(f"latents: {latents.shape}")
clean_memory_on_device(accelerator.device)
image = pipeline.latents_to_image(latents)[0]
if "original_lantent" in prompt_dict:
#Prevent out of VRAM error
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
original_latent = prompt_dict.get("original_lantent")
logger.info(f"original_latent: {original_latent.shape}")
original_image = pipeline.latents_to_image(original_latent)[0]
text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2)
new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height))