Mirror of https://github.com/kohya-ss/sd-scripts.git
Refactor memory cleaning into a single function
library/device_utils.py (new file, +9 lines)
@@ -0,0 +1,9 @@
+import gc
+
+import torch
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
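For reference, this helper unifies the three slightly different patterns removed in the hunks below: gc.collect() followed by torch.cuda.empty_cache(), a bare torch.cuda.empty_cache(), and an availability-guarded flush. One subtlety: the helper flushes the cache before collecting, whereas the removed call sites collected first, so objects freed by the gc pass keep their cached CUDA blocks until the next flush. Typical usage, as a minimal sketch:

    from library.device_utils import clean_memory

    # Call right after large tensors or models are deleted or moved off-device.
    # Safe on CPU-only machines, where only the gc.collect() pass runs.
    clean_memory()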
library/sdxl_train_util.py

@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from typing import Optional
@@ -8,6 +7,7 @@ from accelerate import init_empty_weights
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
+from library.device_utils import clean_memory
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
 
 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)
 
-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
         accelerator.wait_for_everyone()
 
    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
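A note on ordering in the hunk above: clean_memory() runs on each rank before accelerator.wait_for_everyone() synchronizes them, so every process frees its unused memory before training proceeds. A minimal sketch of that call-site shape (model loading elided):

    from accelerate import Accelerator

    from library.device_utils import clean_memory

    accelerator = Accelerator()
    # ... load models and move them to accelerator.device ...
    clean_memory()                   # each rank flushes its own unused memory
    accelerator.wait_for_everyone()  # then all ranks sync before continuing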
library/train_util.py

@@ -20,7 +20,6 @@ from typing import (
     Union,
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
-import gc
 import glob
 import math
 import os
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
 
 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
+from library.device_utils import clean_memory
 from library.original_unet import UNet2DConditionModel
 
 # Tokenizer: use the tokenizer provided in advance rather than loading it from the checkpoint
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
                 info.latents_flipped = flipped_latent
 
     # FIXME this slows down caching a lot, specify this as an option
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clean_memory()
 
 
 def cache_batch_text_encoder_outputs(
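The FIXME survives the refactor: a flush still runs after every cached batch. A sketch of how it could be made opt-in, as the comment suggests (the clean_after_batch flag and the stand-in encode step are hypothetical, not part of the repository):

    import torch

    from library.device_utils import clean_memory

    def cache_all_latents(batches, clean_after_batch: bool = False):
        cached = []
        for batch in batches:
            with torch.no_grad():
                latents = batch * 0.18215  # stand-in for the real VAE encode
            cached.append(latents.to("cpu"))
            if clean_after_batch:
                clean_memory()  # per-batch flush: lower peak VRAM, slower caching
        clean_memory()  # a single flush at the end is comparatively cheap
        return cached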
@@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)
 
-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
        accelerator.wait_for_everyone()
 
    return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4814,7 @@ def sample_images_common(
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
+    clean_memory()
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
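One detail in this last hunk: torch.cuda.empty_cache(), and therefore clean_memory(), can only release allocator blocks that no live tensor still references, so the del pipeline on the preceding line is what makes the pipeline's VRAM reclaimable. A minimal sketch of that ordering, assuming a CUDA device (the tensor is illustrative):

    import torch

    from library.device_utils import clean_memory

    if torch.cuda.is_available():
        x = torch.zeros(256, 1024, 1024, device="cuda")  # ~1 GiB of fp32
        clean_memory()                        # x is still referenced; nothing is released
        print(torch.cuda.memory_allocated())  # ~1 GiB
        del x                                 # drop the last reference first...
        clean_memory()                        # ...then the flush can return the block
        print(torch.cuda.memory_allocated())  # 0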