Refactor memory cleaning into a single function

This commit is contained in:
Aarni Koskela
2024-01-16 14:47:44 +02:00
parent 2e4bee6f24
commit afc38707d5
15 changed files with 46 additions and 65 deletions

View File

@@ -1,7 +1,6 @@
# DreamBooth training
# XXX dropped option: fine_tune
import gc
import argparse
import itertools
import math
@@ -12,6 +11,7 @@ import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
init_ipex()
@@ -138,9 +138,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory()
accelerator.wait_for_everyone()