From 74fe0453b204e5b2c718e821c692595ff2e76c35 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 8 Feb 2024 20:58:54 +0900 Subject: [PATCH] add comment for get_preferred_device --- library/device_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/device_utils.py b/library/device_utils.py index 93371ca6..546fb386 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -15,6 +15,7 @@ except Exception: try: import intel_extension_for_pytorch as ipex # noqa + HAS_XPU = torch.xpu.is_available() except Exception: HAS_XPU = False @@ -32,6 +33,9 @@ def clean_memory(): @functools.lru_cache(maxsize=None) def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ if HAS_CUDA: device = torch.device("cuda") elif HAS_XPU: @@ -43,6 +47,7 @@ def get_preferred_device() -> torch.device: print(f"get_preferred_device() -> {device}") return device + def init_ipex(): """ Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. @@ -54,10 +59,11 @@ def init_ipex(): try: if HAS_XPU: from library.ipex import ipex_init + is_initialized, error_message = ipex_init() if not is_initialized: print("failed to initialize ipex:", error_message) else: return except Exception as e: - print("failed to initialize ipex:", e) \ No newline at end of file + print("failed to initialize ipex:", e)