mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor device determination to function; add MPS fallback
This commit is contained in:
@@ -10,6 +10,7 @@ from einops import repeat
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import get_preferred_device
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
@@ -85,7 +86,7 @@ if __name__ == "__main__":
|
||||
guidance_scale = 7
|
||||
seed = None # 1
|
||||
|
||||
DEVICE = "cuda"
|
||||
DEVICE = get_preferred_device()
|
||||
DTYPE = torch.float16 # bfloat16 may work
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user