From 4012fd24f684d4d371a8736d7e65bee307077e33 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 28 Mar 2024 21:08:16 +0300 Subject: [PATCH] IPEX fix pin_memory --- library/ipex/__init__.py | 7 ++++--- library/ipex/hijacks.py | 17 +++++++++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 972a3bf6..e5aba693 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.nn.Module.cuda = torch.nn.Module.xpu torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialized = torch.xpu.lazy_init._initialized @@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 2024 + ipex._C._DeviceProperties.minor = 0 # Fix functions with ipex: torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 65089f39..d3cef827 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs): else: return original_Tensor_cuda(self, device, *args, **kwargs) +original_Tensor_pin_memory = torch.Tensor.pin_memory +@wraps(torch.Tensor.pin_memory) +def Tensor_pin_memory(self, device=None, *args, **kwargs): + if device is None: + device = "xpu" + if check_device(device): + return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_pin_memory(self, device, *args, **kwargs) + original_UntypedStorage_init = torch.UntypedStorage.__init__ @wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): @@ -259,10 +269,12 @@ def torch_Generator(device=None): original_torch_load = torch.load @wraps(torch.load) def torch_load(f, map_location=None, *args, **kwargs): + if map_location is None: + map_location = "xpu" if check_device(map_location): - return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs) + return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs) else: - return original_torch_load(f, map_location=map_location, *args, **kwargs) + return original_torch_load(f, *args, map_location=map_location, **kwargs) # Hijack Functions: @@ -270,6 +282,7 @@ def ipex_hijacks(): torch.tensor = torch_tensor torch.Tensor.to = Tensor_to torch.Tensor.cuda = Tensor_cuda + torch.Tensor.pin_memory = Tensor_pin_memory torch.UntypedStorage.__init__ = UntypedStorage_init torch.UntypedStorage.cuda = UntypedStorage_cuda torch.empty = torch_empty