mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
reduce memory usage in sample image generation
This commit is contained in:
@@ -402,9 +402,6 @@ def sample_images(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
org_vae_device = vae.device # will be on cpu
|
|
||||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
|
||||||
|
|
||||||
if distributed_state.num_processes <= 1:
|
if distributed_state.num_processes <= 1:
|
||||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -450,8 +447,6 @@ def sample_images(
|
|||||||
if cuda_rng_state is not None:
|
if cuda_rng_state is not None:
|
||||||
torch.cuda.set_rng_state(cuda_rng_state)
|
torch.cuda.set_rng_state(cuda_rng_state)
|
||||||
|
|
||||||
vae.to(org_vae_device)
|
|
||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
|
||||||
@@ -531,12 +526,19 @@ def sample_image_inference(
|
|||||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||||
|
|
||||||
# sample image
|
# sample image
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
with accelerator.autocast():
|
||||||
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
|
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
|
||||||
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
|
|
||||||
|
|
||||||
# latent to image
|
# latent to image
|
||||||
with torch.no_grad():
|
clean_memory_on_device(accelerator.device)
|
||||||
|
org_vae_device = vae.device # will be on cpu
|
||||||
|
vae.to(accelerator.device)
|
||||||
|
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
|
||||||
image = vae.decode(latents)
|
image = vae.decode(latents)
|
||||||
|
vae.to(org_vae_device)
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
image = image.float()
|
image = image.float()
|
||||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user