diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index a44531f3..6df52056 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -1,6 +1,7 @@ import os import sys import torch +from packaging import version try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import has_ipex = True @@ -8,7 +9,7 @@ except Exception: has_ipex = False from .hijacks import ipex_hijacks -torch_version = float(torch.__version__[:3]) +torch_version = version.parse(torch.__version__) # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -71,7 +72,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.__file__ = torch.xpu.__file__ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - if torch_version < 2.3: + if torch_version < version.parse("2.3"): torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork @@ -114,14 +115,14 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.threading = torch.xpu.threading torch.cuda.traceback = torch.xpu.traceback - if torch_version < 2.5: + if torch_version < version.parse("2.5"): torch.cuda.os = torch.xpu.os torch.cuda.Device = torch.xpu.Device torch.cuda.warnings = torch.xpu.warnings torch.cuda.classproperty = torch.xpu.classproperty torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - if torch_version < 2.7: + if torch_version < version.parse("2.7"): torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.List = torch.xpu.List @@ -160,7 +161,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.initial_seed = torch.xpu.initial_seed # C - if torch_version < 2.3: + if torch_version < version.parse("2.3"): torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.major = 12