add clean_memory_on_device and use it from training

This commit is contained in:
Kohya S
2024-02-12 11:10:52 +09:00
parent 75ecb047e2
commit e24d9606a2
13 changed files with 55 additions and 38 deletions

View File

@@ -11,7 +11,7 @@ import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -161,7 +161,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
clean_memory()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -260,7 +260,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)