Refactor memory cleaning into a single function

Aarni Koskela
2024-01-16 14:47:44 +02:00
parent 2e4bee6f24
commit afc38707d5
15 changed files with 46 additions and 65 deletions

library/device_utils.py Normal file

@@ -0,0 +1,9 @@
+import gc
+
+import torch
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
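
The new helper consolidates the two-step cleanup that was previously duplicated across the scripts: empty the CUDA allocator cache when CUDA is present, then collect Python garbage. A minimal usage sketch follows; the call site and function name below are illustrative, not part of this commit:

    from library.device_utils import clean_memory

    def release_after_load():
        # Illustrative call site: once large intermediate objects have gone
        # out of scope, a single call replaces the old pairs of
        # gc.collect() and torch.cuda.empty_cache() calls.
        clean_memory()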

library/sdxl_train_util.py

@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from typing import Optional
@@ -8,6 +7,7 @@ from accelerate import init_empty_weights
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
+from library.device_utils import clean_memory
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)

-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
         accelerator.wait_for_everyone()

     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

library/train_util.py

@@ -20,7 +20,6 @@ from typing import (
     Union,
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
-import gc
 import glob
 import math
 import os
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
+from library.device_utils import clean_memory
 from library.original_unet import UNet2DConditionModel


 # Tokenizer: use the pre-provided tokenizer rather than loading it from the checkpoint
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
                 info.latents_flipped = flipped_latent

     # FIXME this slows down caching a lot, specify this as an option
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clean_memory()


 def cache_batch_text_encoder_outputs(
@@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)

-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
         accelerator.wait_for_everyone()

     return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4814,7 @@ def sample_images_common(

     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
+    clean_memory()

     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
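
The FIXME in cache_batch_latents notes that running the cleanup after every cached batch slows caching considerably and should become an option. A sketch of one possible opt-in gate; the wrapper name and flag below are hypothetical and not part of this commit:

    from library.device_utils import clean_memory

    def maybe_clean_memory(per_batch_cleanup_enabled: bool) -> None:
        # Hypothetical opt-in gate: emptying the CUDA allocator cache after
        # every cached batch is expensive, so only do it when requested.
        if per_batch_cleanup_enabled:
            clean_memory()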