diff --git a/library/device_utils.py b/library/device_utils.py
index 93371ca6..546fb386 100644
--- a/library/device_utils.py
+++ b/library/device_utils.py
@@ -15,6 +15,7 @@
 except Exception:
     try:
         import intel_extension_for_pytorch as ipex  # noqa
+        HAS_XPU = torch.xpu.is_available()
     except Exception:
         HAS_XPU = False
 
@@ -32,6 +33,9 @@ def clean_memory():
 
 @functools.lru_cache(maxsize=None)
 def get_preferred_device() -> torch.device:
+    r"""
+    Do not call this function from training scripts. Use accelerator.device instead.
+    """
     if HAS_CUDA:
         device = torch.device("cuda")
     elif HAS_XPU:
@@ -43,6 +47,7 @@ def get_preferred_device() -> torch.device:
     print(f"get_preferred_device() -> {device}")
     return device
 
+
 def init_ipex():
     """
     Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
@@ -54,10 +59,11 @@ def init_ipex():
     try:
         if HAS_XPU:
             from library.ipex import ipex_init
+
             is_initialized, error_message = ipex_init()
             if not is_initialized:
                 print("failed to initialize ipex:", error_message)
         else:
             return
     except Exception as e:
-        print("failed to initialize ipex:", e)
\ No newline at end of file
+        print("failed to initialize ipex:", e)