mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add comment for get_preferred_device
This commit is contained in:
@@ -15,6 +15,7 @@ except Exception:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex # noqa
|
import intel_extension_for_pytorch as ipex # noqa
|
||||||
|
|
||||||
HAS_XPU = torch.xpu.is_available()
|
HAS_XPU = torch.xpu.is_available()
|
||||||
except Exception:
|
except Exception:
|
||||||
HAS_XPU = False
|
HAS_XPU = False
|
||||||
@@ -32,6 +33,9 @@ def clean_memory():
|
|||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.lru_cache(maxsize=None)
|
||||||
def get_preferred_device() -> torch.device:
|
def get_preferred_device() -> torch.device:
|
||||||
|
r"""
|
||||||
|
Do not call this function from training scripts. Use accelerator.device instead.
|
||||||
|
"""
|
||||||
if HAS_CUDA:
|
if HAS_CUDA:
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
elif HAS_XPU:
|
elif HAS_XPU:
|
||||||
@@ -43,6 +47,7 @@ def get_preferred_device() -> torch.device:
|
|||||||
print(f"get_preferred_device() -> {device}")
|
print(f"get_preferred_device() -> {device}")
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
def init_ipex():
|
def init_ipex():
|
||||||
"""
|
"""
|
||||||
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
||||||
@@ -54,6 +59,7 @@ def init_ipex():
|
|||||||
try:
|
try:
|
||||||
if HAS_XPU:
|
if HAS_XPU:
|
||||||
from library.ipex import ipex_init
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
is_initialized, error_message = ipex_init()
|
is_initialized, error_message = ipex_init()
|
||||||
if not is_initialized:
|
if not is_initialized:
|
||||||
print("failed to initialize ipex:", error_message)
|
print("failed to initialize ipex:", error_message)
|
||||||
|
|||||||
Reference in New Issue
Block a user