Refactor device determination to function; add MPS fallback

This commit is contained in:
Aarni Koskela
2024-01-16 15:01:59 +02:00
parent afc38707d5
commit 478156b4f7
10 changed files with 48 additions and 14 deletions

View File

@@ -66,7 +66,7 @@ import diffusers
import numpy as np
import torch
from library.device_utils import clean_memory
from library.device_utils import clean_memory, get_preferred_device
from library.ipex_interop import init_ipex
init_ipex()
@@ -2324,7 +2324,7 @@ def main(args):
scheduler.config.clip_sample = True
# deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
device = get_preferred_device()
# custom pipelineをコピったやつを生成する
if args.vae_slices: