Merge pull request #1054 from akx/mps

Device support improvements (MPS)
This commit is contained in:
Kohya S
2024-01-31 21:30:12 +09:00
committed by GitHub
22 changed files with 91 additions and 75 deletions

View File

@@ -1,6 +1,5 @@
import importlib
import argparse
import gc
import math
import os
import sys
@@ -14,6 +13,7 @@ from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
init_ipex()
@@ -266,9 +266,7 @@ class NetworkTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory()
accelerator.wait_for_everyone()