reduce memory usage in sample image generation

This commit is contained in:
Kohya S
2024-10-24 20:35:47 +09:00
parent 623017f716
commit e3c43bda49

View File

@@ -402,9 +402,6 @@ def sample_images(
         except Exception:
             pass

-        org_vae_device = vae.device  # will be on cpu
-        vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device
         if distributed_state.num_processes <= 1:
             # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
             with torch.no_grad():
@@ -450,8 +447,6 @@ def sample_images(
         if cuda_rng_state is not None:
             torch.cuda.set_rng_state(cuda_rng_state)

-    vae.to(org_vae_device)
     clean_memory_on_device(accelerator.device)
@@ -531,12 +526,19 @@ def sample_image_inference(
         neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

     # sample image
+    clean_memory_on_device(accelerator.device)
+    with accelerator.autocast():
         latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
-    latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))

     # latent to image
-    with torch.no_grad():
+    clean_memory_on_device(accelerator.device)
+    org_vae_device = vae.device  # will be on cpu
+    vae.to(accelerator.device)
+    latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
     image = vae.decode(latents)
+    vae.to(org_vae_device)
+    clean_memory_on_device(accelerator.device)

     image = image.float()
     image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
     decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)