Refactor memory cleaning into a single function

This commit is contained in:
Aarni Koskela
2024-01-16 14:47:44 +02:00
parent 2e4bee6f24
commit afc38707d5
15 changed files with 46 additions and 65 deletions

View File

@@ -20,7 +20,6 @@ from typing import (
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob
import math
import os
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.device_utils import clean_memory
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
info.latents_flipped = flipped_latent
# FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
def cache_batch_text_encoder_outputs(
@@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
clean_memory()
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4814,7 @@ def sample_images_common(
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
clean_memory()
torch.set_rng_state(rng_state)
if cuda_rng_state is not None: