From a9c6182b3fb61ad73375497f624e873e097242b8 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 5 Dec 2023 19:52:31 +0300 Subject: [PATCH] Cleanup IPEX libs --- library/ipex/__init__.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index cda32ccb..662572c8 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -4,13 +4,12 @@ import contextlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks -from .attention import attention_init # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements try: - #Replace cuda with xpu: + # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.device = torch.xpu.device @@ -91,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - #Memory: + # Memory: torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None @@ -113,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - #RNG: + # RNG: torch.cuda.get_rng_state = torch.xpu.get_rng_state torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all torch.cuda.set_rng_state = torch.xpu.set_rng_state @@ -124,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.seed_all = torch.xpu.seed_all torch.cuda.initial_seed = torch.xpu.initial_seed - #AMP: + # AMP: torch.cuda.amp = torch.xpu.amp if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() @@ -139,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements except Exception: # pylint: disable=broad-exception-caught torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C + # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 - #Fix functions with ipex: + # Fix functions with ipex: torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True @@ -166,7 +165,11 @@ def ipex_init(): # pylint: disable=too-many-statements ipex_hijacks() if not torch.xpu.has_fp64_dtype(): - attention_init() + try: + from .attention import attention_init + attention_init() + except Exception: # pylint: disable=broad-exception-caught + pass try: from .diffusers import ipex_diffusers ipex_diffusers()