diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 9ec69012..19ec8eea 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -16,7 +16,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.device = torch.xpu.device torch.cuda.device_count = torch.xpu.device_count torch.cuda.device_of = torch.xpu.device_of - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard torch.cuda.get_device_name = torch.xpu.get_device_name torch.cuda.get_device_properties = torch.xpu.get_device_properties torch.cuda.init = torch.xpu.init @@ -157,6 +156,13 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 + # getDeviceIdListForCard is renamed since https://github.com/intel/intel-extension-for-pytorch/commit/835b41fd5c8b6facf9efee8312f20699850ee592 + if hasattr(torch.xpu, 'getDeviceIdListForCard'): + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard + else: + torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card + torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() attention_init()