Refactor device determination to function; add MPS fallback

This commit is contained in:
Aarni Koskela
2024-01-16 15:01:59 +02:00
parent afc38707d5
commit 478156b4f7
10 changed files with 48 additions and 14 deletions

View File

@@ -15,6 +15,8 @@ from torch import nn
from tqdm import tqdm
from PIL import Image
from library.device_utils import get_preferred_device
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -255,7 +257,7 @@ def create_upscaler(**kwargs):
# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)