diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index a44531f3..92a88e23 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -1,6 +1,7 @@ import os import sys import torch +from packaging import version try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import has_ipex = True @@ -8,7 +9,7 @@ except Exception: has_ipex = False from .hijacks import ipex_hijacks -torch_version = float(torch.__version__[:3]) +torch_version = version.parse(torch.__version__) # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.__path__ = torch.xpu.__path__ torch.cuda.set_stream = torch.xpu.set_stream torch.cuda.torch = torch.xpu.torch - torch.cuda.Union = torch.xpu.Union torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__builtins__ = torch.xpu.__builtins__ @@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.StreamContext = torch.xpu.StreamContext torch.cuda._lazy_call = torch.xpu._lazy_call torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t torch.cuda.__spec__ = torch.xpu.__spec__ torch.cuda.__file__ = torch.xpu.__file__ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - if torch_version < 2.3: + if torch_version < version.parse("2.3"): torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork @@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.threading = torch.xpu.threading torch.cuda.traceback = torch.xpu.traceback - if torch_version < 2.5: + if torch_version < version.parse("2.5"): torch.cuda.os = torch.xpu.os torch.cuda.Device = torch.xpu.Device torch.cuda.warnings = torch.xpu.warnings torch.cuda.classproperty = torch.xpu.classproperty torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - if torch_version < 2.7: + if torch_version < version.parse("2.7"): torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.List = torch.xpu.List + if torch_version < version.parse("2.11"): + torch.cuda._device_t = torch.xpu._device_t + torch.cuda._device = torch.xpu._device + torch.cuda.Union = torch.xpu.Union + # Memory: if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): @@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.initial_seed = torch.xpu.initial_seed # C - if torch_version < 2.3: + if torch_version < version.parse("2.3"): torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.major = 12