add clean_memory_on_device and use it from training

This commit is contained in:
Kohya S
2024-02-12 11:10:52 +09:00
parent 75ecb047e2
commit e24d9606a2
13 changed files with 55 additions and 38 deletions

View File

@@ -31,6 +31,21 @@ def clean_memory():
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""

View File

@@ -4,7 +4,7 @@ import os
from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
@@ -50,7 +50,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device)
vae.to(accelerator.device)
clean_memory()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

View File

@@ -32,7 +32,7 @@ import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -2285,7 +2285,7 @@ def cache_batch_latents(
info.latents_flipped = flipped_latent
if not HIGH_VRAM:
clean_memory()
clean_memory_on_device(vae.device)
def cache_batch_text_encoder_outputs(
@@ -4026,7 +4026,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
unet.to(accelerator.device)
vae.to(accelerator.device)
clean_memory()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4695,7 +4695,7 @@ def sample_images_common(
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
org_vae_device = vae.device # CPUにいるはず
vae.to(distributed_state.device)
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
@@ -4752,7 +4752,11 @@ def sample_images_common(
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
@@ -4774,8 +4778,10 @@ def sample_images_common(
# clear pipeline and cache to reduce vram usage
del pipeline
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
@@ -4870,10 +4876,6 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
pass
# endregion
# # clear pipeline and cache to reduce vram usage
# del pipeline
# torch.cuda.empty_cache()
# region 前処理用