mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Refactor device determination to function; add MPS fallback
This commit is contained in:
@@ -15,6 +15,8 @@ from torch import nn
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||
@@ -255,7 +257,7 @@ def create_upscaler(**kwargs):
|
||||
|
||||
# another interface: upscale images with a model for given images from command line
|
||||
def upscale_images(args: argparse.Namespace):
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
DEVICE = get_preferred_device()
|
||||
us_dtype = torch.float16 # TODO: support fp32/bf16
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user