Refactor memory cleaning into a single function

This commit is contained in:
Aarni Koskela
2024-01-16 14:47:44 +02:00
parent 2e4bee6f24
commit afc38707d5
15 changed files with 46 additions and 65 deletions

View File

@@ -2,7 +2,6 @@
# XXX dropped option: hypernetwork training # XXX dropped option: hypernetwork training
import argparse import argparse
import gc
import math import math
import os import os
from multiprocessing import Value from multiprocessing import Value
@@ -11,6 +10,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -158,9 +158,7 @@ def train(args):
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -66,6 +66,7 @@ import diffusers
import numpy as np import numpy as np
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -888,8 +889,7 @@ class PipelineLike:
init_latent_dist = self.vae.encode(init_image).latent_dist init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator) init_latents = init_latent_dist.sample(generator=generator)
else: else:
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
init_latents = [] init_latents = []
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode( init_latent_dist = self.vae.encode(
@@ -1047,8 +1047,7 @@ class PipelineLike:
if vae_batch_size >= batch_size: if vae_batch_size >= batch_size:
image = self.vae.decode(latents).sample image = self.vae.decode(latents).sample
else: else:
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
images = [] images = []
for i in tqdm(range(0, batch_size, vae_batch_size)): for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append( images.append(

9
library/device_utils.py Normal file
View File

@@ -0,0 +1,9 @@
import gc
import torch
def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import gc
import math import math
import os import os
from typing import Optional from typing import Optional
@@ -8,6 +7,7 @@ from accelerate import init_empty_weights
from tqdm import tqdm from tqdm import tqdm
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.device_utils import clean_memory
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device) unet.to(accelerator.device)
vae.to(accelerator.device) vae.to(accelerator.device)
gc.collect() clean_memory()
torch.cuda.empty_cache()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

View File

@@ -20,7 +20,6 @@ from typing import (
Union, Union,
) )
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob import glob
import math import math
import os import os
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor # from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork # from library.hypernetwork import replace_attentions_for_hypernetwork
from library.device_utils import clean_memory
from library.original_unet import UNet2DConditionModel from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
info.latents_flipped = flipped_latent info.latents_flipped = flipped_latent
# FIXME this slows down caching a lot, specify this as an option # FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
def cache_batch_text_encoder_outputs( def cache_batch_text_encoder_outputs(
@@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
unet.to(accelerator.device) unet.to(accelerator.device)
vae.to(accelerator.device) vae.to(accelerator.device)
gc.collect() clean_memory()
torch.cuda.empty_cache()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4814,7 @@ def sample_images_common(
# clear pipeline and cache to reduce vram usage # clear pipeline and cache to reduce vram usage
del pipeline del pipeline
torch.cuda.empty_cache() clean_memory()
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
if cuda_rng_state is not None: if cuda_rng_state is not None:

View File

@@ -18,6 +18,7 @@ import diffusers
import numpy as np import numpy as np
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -640,8 +641,7 @@ class PipelineLike:
init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
init_latents = init_latent_dist.sample(generator=generator) init_latents = init_latent_dist.sample(generator=generator)
else: else:
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
init_latents = [] init_latents = []
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode( init_latent_dist = self.vae.encode(
@@ -780,8 +780,7 @@ class PipelineLike:
if vae_batch_size >= batch_size: if vae_batch_size >= batch_size:
image = self.vae.decode(latents.to(self.vae.dtype)).sample image = self.vae.decode(latents.to(self.vae.dtype)).sample
else: else:
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
images = [] images = []
for i in tqdm(range(0, batch_size, vae_batch_size)): for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append( images.append(
@@ -796,8 +795,7 @@ class PipelineLike:
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
if output_type == "pil": if output_type == "pil":
# image = self.numpy_to_pil(image) # image = self.numpy_to_pil(image)

View File

@@ -1,7 +1,6 @@
# training with captions # training with captions
import argparse import argparse
import gc
import math import math
import os import os
from multiprocessing import Value from multiprocessing import Value
@@ -11,6 +10,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -252,9 +252,7 @@ def train(args):
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -407,8 +405,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32) text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
else: else:
# make sure Text Encoders are on GPU # make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)

View File

@@ -2,7 +2,6 @@
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward # training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse import argparse
import gc
import json import json
import math import math
import os import os
@@ -15,6 +14,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -164,9 +164,7 @@ def train(args):
accelerator.is_main_process, accelerator.is_main_process,
) )
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -291,8 +289,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32) text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
else: else:
# make sure Text Encoders are on GPU # make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import gc
import json import json
import math import math
import os import os
@@ -12,6 +11,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -163,9 +163,7 @@ def train(args):
accelerator.is_main_process, accelerator.is_main_process,
) )
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -264,8 +262,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32) text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
else: else:
# make sure Text Encoders are on GPU # make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)

View File

@@ -1,6 +1,7 @@
import argparse import argparse
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -65,8 +66,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
org_unet_device = unet.device org_unet_device = unet.device
vae.to("cpu") vae.to("cpu")
unet.to("cpu") unet.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
# When TE is not be trained, it will not be prepared so we need to use explicit autocast # When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast(): with accelerator.autocast():
@@ -81,8 +81,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32) text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
if not args.lowram: if not args.lowram:
print("move vae and unet back to original device") print("move vae and unet back to original device")

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import gc
import json import json
import math import math
import os import os
@@ -12,6 +11,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -219,9 +219,7 @@ def train(args):
accelerator.is_main_process, accelerator.is_main_process,
) )
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -1,7 +1,6 @@
# DreamBooth training # DreamBooth training
# XXX dropped option: fine_tune # XXX dropped option: fine_tune
import gc
import argparse import argparse
import itertools import itertools
import math import math
@@ -12,6 +11,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -138,9 +138,7 @@ def train(args):
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -1,6 +1,5 @@
import importlib import importlib
import argparse import argparse
import gc
import math import math
import os import os
import sys import sys
@@ -14,6 +13,7 @@ from tqdm import tqdm
import torch import torch
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -266,9 +266,7 @@ class NetworkTrainer:
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import gc
import math import math
import os import os
from multiprocessing import Value from multiprocessing import Value
@@ -8,6 +7,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -363,9 +363,7 @@ class TextualInversionTrainer:
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -1,6 +1,5 @@
import importlib import importlib
import argparse import argparse
import gc
import math import math
import os import os
import toml import toml
@@ -9,6 +8,7 @@ from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex from library.ipex_interop import init_ipex
init_ipex() init_ipex()
@@ -286,9 +286,7 @@ def train(args):
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): clean_memory()
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone() accelerator.wait_for_everyone()