add comment for get_preferred_device

This commit is contained in:
Kohya S
2024-02-08 20:58:54 +09:00
parent 5cca1fdc40
commit 74fe0453b2

View File

@@ -15,6 +15,7 @@ except Exception:
try: try:
import intel_extension_for_pytorch as ipex # noqa import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available() HAS_XPU = torch.xpu.is_available()
except Exception: except Exception:
HAS_XPU = False HAS_XPU = False
@@ -32,6 +33,9 @@ def clean_memory():
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device: def get_preferred_device() -> torch.device:
r"""
Do not call this function from training scripts. Use accelerator.device instead.
"""
if HAS_CUDA: if HAS_CUDA:
device = torch.device("cuda") device = torch.device("cuda")
elif HAS_XPU: elif HAS_XPU:
@@ -43,6 +47,7 @@ def get_preferred_device() -> torch.device:
print(f"get_preferred_device() -> {device}") print(f"get_preferred_device() -> {device}")
return device return device
def init_ipex(): def init_ipex():
""" """
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
@@ -54,6 +59,7 @@ def init_ipex():
try: try:
if HAS_XPU: if HAS_XPU:
from library.ipex import ipex_init from library.ipex import ipex_init
is_initialized, error_message = ipex_init() is_initialized, error_message = ipex_init()
if not is_initialized: if not is_initialized:
print("failed to initialize ipex:", error_message) print("failed to initialize ipex:", error_message)