Refactor device determination to function; add MPS fallback

This commit is contained in:
Aarni Koskela
2024-01-16 15:01:59 +02:00
parent afc38707d5
commit 478156b4f7
10 changed files with 48 additions and 14 deletions

View File

@@ -14,7 +14,9 @@ from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from library.device_utils import get_preferred_device
DEVICE = get_preferred_device()
IMAGE_TRANSFORMS = transforms.Compose(
[