From e3c43bda49ec8c5a5cb784e29f8610f1ebff0a66 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 20:35:47 +0900 Subject: [PATCH] reduce memory usage in sample image generation --- library/sd3_train_utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 9282482d..af8ecf2c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -402,9 +402,6 @@ def sample_images( except Exception: pass - org_vae_device = vae.device # will be on cpu - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device - if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): @@ -450,8 +447,6 @@ def sample_images( if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) - clean_memory_on_device(accelerator.device) @@ -531,12 +526,19 @@ def sample_image_inference( neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image - latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) - latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + clean_memory_on_device(accelerator.device) + with accelerator.autocast(): + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) # latent to image - with torch.no_grad(): - image = vae.decode(latents) + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)