add clean_memory_on_device and use it from training

This commit is contained in:
Kohya S
2024-02-12 11:10:52 +09:00
parent 75ecb047e2
commit e24d9606a2
13 changed files with 55 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
import argparse
import torch
from library.device_utils import init_ipex, clean_memory
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
@@ -64,7 +64,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory()
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
@@ -79,7 +79,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
clean_memory()
clean_memory_on_device(accelerator.device)
if not args.lowram:
print("move vae and unet back to original device")