mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
reduce memory usage in sample image generation
This commit is contained in:
@@ -402,9 +402,6 @@ def sample_images(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
org_vae_device = vae.device # will be on cpu
|
||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
@@ -450,8 +447,6 @@ def sample_images(
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
vae.to(org_vae_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
@@ -531,12 +526,19 @@ def sample_image_inference(
|
||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
# sample image
|
||||
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
|
||||
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
|
||||
clean_memory_on_device(accelerator.device)
|
||||
with accelerator.autocast():
|
||||
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
|
||||
|
||||
# latent to image
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latents)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = vae.device # will be on cpu
|
||||
vae.to(accelerator.device)
|
||||
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
|
||||
image = vae.decode(latents)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
image = image.float()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
|
||||
Reference in New Issue
Block a user