Refactor memory cleaning into a single function

This commit is contained in:
Aarni Koskela
2024-01-16 14:47:44 +02:00
parent 2e4bee6f24
commit afc38707d5
15 changed files with 46 additions and 65 deletions

View File

@@ -1,6 +1,7 @@
import argparse
import torch
+ from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
init_ipex()
@@ -65,8 +66,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
- if torch.cuda.is_available():
-     torch.cuda.empty_cache()
+ clean_memory()
# When TE will not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
@@ -81,8 +81,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
- if torch.cuda.is_available():
-     torch.cuda.empty_cache()
+ clean_memory()
if not args.lowram:
print("move vae and unet back to original device")