Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
add clean_memory_on_device and use it from training
@@ -31,6 +31,21 @@ def clean_memory():
     torch.mps.empty_cache()
 
 
+def clean_memory_on_device(device: torch.device):
+    r"""
+    Clean memory on the specified device; called from training scripts.
+    """
+    gc.collect()
+
+    # device may be "cuda" or "cuda:0", so we need to check the type of the device
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
+    if device.type == "xpu":
+        torch.xpu.empty_cache()
+    if device.type == "mps":
+        torch.mps.empty_cache()
+
+
 @functools.lru_cache(maxsize=None)
 def get_preferred_device() -> torch.device:
     r"""
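
Note on the new helper: torch.device normalizes an indexed device string, so "cuda:0" and "cuda" both report type "cuda"; that is why the function branches on device.type instead of comparing the device to a string. A minimal sketch of that behavior (illustrative, not part of the commit):

    import torch

    # torch.device splits "cuda:0" into a type and an index,
    # so a .type check covers both "cuda" and "cuda:0".
    dev = torch.device("cuda:0")
    print(dev.type)   # -> "cuda"
    print(dev.index)  # -> 0
    print(torch.device("cuda").index)  # -> None (no explicit index)
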
@@ -4,7 +4,7 @@ import os
 from typing import Optional
 
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 
 from accelerate import init_empty_weights
@@ -50,7 +50,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
     unet.to(accelerator.device)
     vae.to(accelerator.device)
 
-    clean_memory()
+    clean_memory_on_device(accelerator.device)
     accelerator.wait_for_everyone()
 
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
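
The updated call sites follow the same pattern: move the modules to the accelerator device, then free cached memory on that same device. A hedged sketch of the pattern (move_and_clean is a hypothetical wrapper, not code from this commit):

    import torch
    from library.device_utils import clean_memory_on_device

    def move_and_clean(modules, device: torch.device):
        # Move each module to the target device, then release cached
        # allocations on that device rather than on a hard-coded backend.
        for m in modules:
            m.to(device)
        clean_memory_on_device(device)
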
@@ -32,7 +32,7 @@ import toml
 from tqdm import tqdm
 
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
@@ -2285,7 +2285,7 @@ def cache_batch_latents(
             info.latents_flipped = flipped_latent
 
     if not HIGH_VRAM:
-        clean_memory()
+        clean_memory_on_device(vae.device)
 
 
 def cache_batch_text_encoder_outputs(
@@ -4026,7 +4026,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
     unet.to(accelerator.device)
     vae.to(accelerator.device)
 
-    clean_memory()
+    clean_memory_on_device(accelerator.device)
     accelerator.wait_for_everyone()
 
     return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4695,7 +4695,7 @@ def sample_images_common(
     distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
 
     org_vae_device = vae.device  # should be on the CPU
-    vae.to(distributed_state.device)
+    vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device
 
     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet)
@@ -4752,7 +4752,11 @@ def sample_images_common(
 
     # save random state to restore later
     rng_state = torch.get_rng_state()
-    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None  # TODO mps etc. support
+    cuda_rng_state = None
+    try:
+        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    except Exception:
+        pass
 
     if distributed_state.num_processes <= 1:
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
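
Wrapping the CUDA RNG capture in try/except lets sampling run on backends where torch.cuda raises instead of returning cleanly; the replaced line only carried a "TODO mps etc. support" note. A sketch of the full save/restore pairing, matching the restore lines shown in the next hunk (assumption: state is restored only when one was captured):

    import torch

    # save random state to restore later
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        pass  # no usable CUDA runtime on this backend

    # ... seeded sampling would happen here ...

    # restore
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)
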
@@ -4774,8 +4778,10 @@ def sample_images_common(
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    with torch.cuda.device(torch.cuda.current_device()):
-        torch.cuda.empty_cache()
+    # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
+    # with torch.cuda.device(torch.cuda.current_device()):
+    #     torch.cuda.empty_cache()
+    clean_memory_on_device(accelerator.device)
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
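
For context: the removed lines pinned the current CUDA device with a torch.cuda.device context before calling empty_cache(), while the new helper receives the target device explicitly; the commit's own comment notes the uncertainty about which form is correct. A sketch of the context-manager form being replaced (illustrative; empty_cache_on is a hypothetical name):

    import torch

    def empty_cache_on(device: torch.device):
        # Make `device` the current CUDA device for the duration of the
        # block, then call empty_cache() under that selection.
        if device.type == "cuda" and torch.cuda.is_available():
            with torch.cuda.device(device):
                torch.cuda.empty_cache()
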
@@ -4870,10 +4876,6 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
         pass
     # endregion
 
-    # # clear pipeline and cache to reduce vram usage
-    # del pipeline
-    # torch.cuda.empty_cache()
-
 
 
 # region for preprocessing