From e24d9606a2c890cf9f015239e04c82eecbda8bce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 11:10:52 +0900 Subject: [PATCH] add clean_memory_on_device and use it from training --- fine_tune.py | 4 ++-- library/device_utils.py | 15 +++++++++++++++ library/sdxl_train_util.py | 4 ++-- library/train_util.py | 24 +++++++++++++----------- sdxl_train.py | 6 +++--- sdxl_train_control_net_lllite.py | 6 +++--- sdxl_train_control_net_lllite_old.py | 6 +++--- sdxl_train_network.py | 6 +++--- train_controlnet.py | 6 +++--- train_db.py | 4 ++-- train_network.py | 4 ++-- train_textual_inversion.py | 4 ++-- train_textual_inversion_XTI.py | 4 ++-- 13 files changed, 55 insertions(+), 38 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index a6a5c1e2..72bae972 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -156,7 +156,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/library/device_utils.py b/library/device_utils.py index 546fb386..8823c5d9 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -31,6 +31,21 @@ def clean_memory(): torch.mps.empty_cache() +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + @functools.lru_cache(maxsize=None) def get_preferred_device() -> torch.device: r""" diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 0ecf4feb..aa36c544 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -4,7 +4,7 @@ import os from typing import Optional import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate import init_empty_weights @@ -50,7 +50,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/train_util.py b/library/train_util.py index bf224b3e..47a767bd 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -32,7 +32,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -2285,7 +2285,7 @@ def cache_batch_latents( info.latents_flipped = flipped_latent if not HIGH_VRAM: - clean_memory() + clean_memory_on_device(vae.device) def cache_batch_text_encoder_outputs( @@ -4026,7 +4026,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio unet.to(accelerator.device) vae.to(accelerator.device) - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return text_encoder, vae, unet, load_stable_diffusion_format @@ -4695,7 +4695,7 @@ def sample_images_common( distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here org_vae_device = vae.device # CPUにいるはず - vae.to(distributed_state.device) + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4752,7 +4752,11 @@ def sample_images_common( # save random state to restore later rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. @@ -4774,8 +4778,10 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. + # with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) if cuda_rng_state is not None: @@ -4870,10 +4876,6 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p pass # endregion - # # clear pipeline and cache to reduce vram usage - # del pipeline - # torch.cuda.empty_cache() - # region 前処理用 diff --git a/sdxl_train.py b/sdxl_train.py index b0dcdbe9..b7789d4b 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -10,7 +10,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -250,7 +250,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -403,7 +403,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 1b069624..c37b847c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,7 +14,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -162,7 +162,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -287,7 +287,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index b74e3b90..35747c1e 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,7 +11,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -161,7 +161,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -260,7 +260,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 205b526e..fd04a572 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,7 +1,7 @@ import argparse import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from library import sdxl_model_util, sdxl_train_util, train_util @@ -64,7 +64,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -79,7 +79,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - clean_memory() + clean_memory_on_device(accelerator.device) if not args.lowram: print("move vae and unet back to original device") diff --git a/train_controlnet.py b/train_controlnet.py index e7a06ae1..6ff2e781 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -11,7 +11,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -217,8 +217,8 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - clean_memory() - + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() if args.gradient_checkpointing: diff --git a/train_db.py b/train_db.py index f6795dce..085ef2ce 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -136,7 +136,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index f8dd6ab2..046edb5b 100644 --- a/train_network.py +++ b/train_network.py @@ -12,7 +12,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -265,7 +265,7 @@ class NetworkTrainer: with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d18837bd..c34cfc96 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -7,7 +7,7 @@ import toml from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -361,7 +361,7 @@ class TextualInversionTrainer: with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 9d4b0aef..b6320082 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,7 +8,7 @@ from multiprocessing import Value from tqdm import tqdm import torch -from library.device_utils import init_ipex, clean_memory +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -284,7 +284,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - clean_memory() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone()