Reduce peak VRAM usage during sample image generation

This commit is contained in:
Kohya S
2024-02-04 17:31:01 +09:00
parent 2f9a344297
commit e793d7780d

View File

@@ -4820,6 +4820,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
controlnet=controlnet,
controlnet_image=controlnet_image,
)
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
image = pipeline.latents_to_image(latents)[0]
    # NOTE: adding accelerator.wait_for_everyone() here would synchronize processes and ensure sample images are saved in the same order as the original prompt list