From b99cd2a9209a6bd46c43dc6d46965be06348fd01 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Wed, 20 Sep 2023 17:16:06 +0300 Subject: [PATCH] Update getDeviceIdListForCard --- library/ipex/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 9ec69012..19ec8eea 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -16,7 +16,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.device = torch.xpu.device torch.cuda.device_count = torch.xpu.device_count torch.cuda.device_of = torch.xpu.device_of - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard torch.cuda.get_device_name = torch.xpu.get_device_name torch.cuda.get_device_properties = torch.xpu.get_device_properties torch.cuda.init = torch.xpu.init @@ -157,6 +156,13 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 + # getDeviceIdListForCard is renamed since https://github.com/intel/intel-extension-for-pytorch/commit/835b41fd5c8b6facf9efee8312f20699850ee592 + if hasattr(torch.xpu, 'getDeviceIdListForCard'): + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard + else: + torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card + torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() attention_init()