add comment for get_preferred_device

This commit is contained in:
Kohya S
2024-02-08 20:58:54 +09:00
parent 5cca1fdc40
commit 74fe0453b2

View File

@@ -15,6 +15,7 @@ except Exception:
try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
@@ -32,6 +33,9 @@ def clean_memory():
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""
Do not call this function from training scripts. Use accelerator.device instead.
"""
if HAS_CUDA:
device = torch.device("cuda")
elif HAS_XPU:
@@ -43,6 +47,7 @@ def get_preferred_device() -> torch.device:
print(f"get_preferred_device() -> {device}")
return device
def init_ipex():
"""
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
@@ -54,10 +59,11 @@ def init_ipex():
try:
if HAS_XPU:
from library.ipex import ipex_init
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
else:
return
except Exception as e:
print("failed to initialize ipex:", e)
print("failed to initialize ipex:", e)