"""Helpers for detecting the available torch backend and freeing its memory."""

import functools
import gc

import torch


def _probe(check) -> bool:
    """Run an availability check, treating any raised exception as "not available"."""
    try:
        return check()
    except Exception:
        return False


# Attribute lookups are deferred into lambdas so that a missing backend module
# (for example on a torch build without MPS support) is caught by the probe.
HAS_CUDA = _probe(lambda: torch.cuda.is_available())
HAS_MPS = _probe(lambda: torch.backends.mps.is_available())

try:
    # XPU is only considered when the Intel extension can be imported.
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    HAS_XPU = False
else:
    HAS_XPU = _probe(lambda: torch.xpu.is_available())


def clean_memory():
    """Run the garbage collector, then empty the cache of every detected backend."""
    gc.collect()
    detected_backends = (
        (HAS_CUDA, "cuda"),
        (HAS_XPU, "xpu"),
        (HAS_MPS, "mps"),
    )
    for is_present, backend_name in detected_backends:
        if is_present:
            # Looked up lazily: the submodule is only touched when it was detected.
            getattr(torch, backend_name).empty_cache()


@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    """
    Return the preferred torch device, in the order CUDA, XPU, MPS, CPU.

    The result is cached, so the selection (and the message printed with it)
    happens only on the first call.
    """
    priority = (
        (HAS_CUDA, "cuda"),
        (HAS_XPU, "xpu"),
        (HAS_MPS, "mps"),
        (True, "cpu"),  # unconditional fallback
    )
    device_name = next(name for is_present, name in priority if is_present)
    device = torch.device(device_name)
    print(f"get_preferred_device() -> {device}")
    return device


def init_ipex():
    """
    Apply the IPEX-to-CUDA hijacks by calling `library.ipex.ipex_init`.

    Intended to be called right after importing torch, before anything else.
    Does nothing when XPU was not detected. Failures are printed, never raised.
    """
    if not HAS_XPU:
        return

    try:
        from library.ipex import ipex_init

        succeeded, reason = ipex_init()
        if not succeeded:
            print("failed to initialize ipex:", reason)
    except Exception as exc:
        print("failed to initialize ipex:", exc)