Refactor device determination to function; add MPS fallback

This commit is contained in:
Aarni Koskela
2024-01-16 15:01:59 +02:00
parent afc38707d5
commit 478156b4f7
10 changed files with 48 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ from einops import repeat
import numpy as np
import torch
from library.device_utils import get_preferred_device
from library.ipex_interop import init_ipex
init_ipex()
@@ -85,7 +86,7 @@ if __name__ == "__main__":
guidance_scale = 7
seed = None # 1
DEVICE = "cuda"
DEVICE = get_preferred_device()
DTYPE = torch.float16 # bfloat16 may work
parser = argparse.ArgumentParser()