add clean_memory_on_device and use it from training

This commit is contained in:
Kohya S
2024-02-12 11:10:52 +09:00
parent 75ecb047e2
commit e24d9606a2
13 changed files with 55 additions and 38 deletions

View File

@@ -12,7 +12,7 @@ import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -265,7 +265,7 @@ class NetworkTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()