Refactor device determination to function; add MPS fallback

This commit is contained in:
Aarni Koskela
2024-01-16 15:01:59 +02:00
parent afc38707d5
commit 478156b4f7
10 changed files with 48 additions and 14 deletions

View File

@@ -15,8 +15,9 @@ from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.device_utils import get_preferred_device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
IMAGE_SIZE = 384

View File

@@ -10,9 +10,9 @@ from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
from library.device_utils import get_preferred_device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),

View File

@@ -14,7 +14,9 @@ from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from library.device_utils import get_preferred_device
DEVICE = get_preferred_device()
IMAGE_TRANSFORMS = transforms.Compose(
[