diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 662572c8..c7854791 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -156,20 +156,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - if hasattr(torch.xpu, 'getDeviceIdListForCard'): - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard - else: - torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card - torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() if not torch.xpu.has_fp64_dtype(): - try: - from .attention import attention_init - attention_init() - except Exception: # pylint: disable=broad-exception-caught - pass try: from .diffusers import ipex_diffusers ipex_diffusers() diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 52016466..ced59637 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -4,11 +4,8 @@ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unuse # pylint: disable=protected-access, missing-function-docstring, line-too-long original_torch_bmm = torch.bmm -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: +def torch_bmm_32_bit(input, mat2, *, out=None): + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] block_multiply = input.element_size() slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply @@ -17,7 +14,7 @@ def torch_bmm(input, mat2, *, out=None): split_slice_size = batch_size_attention if block_size > 4: do_split = True - #Find something divisible with the input_tokens + # Find something divisible with the input_tokens while (split_slice_size * slice_block_size) > 4: split_slice_size = split_slice_size // 2 if split_slice_size <= 1: @@ -30,7 +27,7 @@ def torch_bmm(input, mat2, *, out=None): if split_slice_size * slice_block_size > 4: slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply do_split_2 = True - #Find something divisible with the input_tokens + # Find something divisible with the input_tokens while (split_2_slice_size * slice_block_size2) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: @@ -64,8 +61,8 @@ def torch_bmm(input, mat2, *, out=None): return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: if len(query.shape) == 3: batch_size_attention, query_tokens, shape_four = query.shape shape_one = 1 @@ -74,11 +71,6 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. shape_one, batch_size_attention, query_tokens, shape_four = query.shape no_shape_one = False - if query.dtype != key.dtype: - key = key.to(dtype=query.dtype) - if query.dtype != value.dtype: - value = value.to(dtype=query.dtype) - block_multiply = query.element_size() slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply block_size = batch_size_attention * slice_block_size @@ -86,7 +78,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. split_slice_size = batch_size_attention if block_size > 4: do_split = True - #Find something divisible with the shape_one + # Find something divisible with the shape_one while (split_slice_size * slice_block_size) > 4: split_slice_size = split_slice_size // 2 if split_slice_size <= 1: @@ -99,7 +91,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. if split_slice_size * slice_block_size > 4: slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply do_split_2 = True - #Find something divisible with the batch_size_attention + # Find something divisible with the batch_size_attention while (split_2_slice_size * slice_block_size2) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: @@ -155,8 +147,3 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal ) return hidden_states - -def attention_init(): - #ARC GPUs can't allocate more than 4GB to a single block: - torch.bmm = torch_bmm - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 4a9a3569..a699e1e4 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -117,6 +117,31 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name else: return original_linalg_solve(A, B, *args, **kwargs) +if torch.xpu.has_fp64_dtype(): + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +else: + # 64 bit attention workarounds for Alchemist: + try: + from .attention import torch_bmm_32_bit as original_torch_bmm + from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + except Exception: # pylint: disable=broad-exception-caught + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + +# dtype errors: +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + @property def is_cuda(self): return self.device.type == 'xpu' @@ -156,10 +181,10 @@ def ipex_hijacks(): lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) - - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + if hasattr(torch.xpu, "Generator"): + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") # TiledVAE and ControlNet: CondFunc('torch.batch_norm', @@ -208,11 +233,16 @@ def ipex_hijacks(): lambda orig_func, *args, **kwargs: True) # Functions that make compile mad with CondFunc: - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers torch.nn.DataParallel = DummyDataParallel + torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + torch.autocast = ipex_autocast - torch.cat = torch_cat - torch.linalg.solve = linalg_solve - torch.UntypedStorage.is_cuda = is_cuda - torch.nn.functional.interpolate = interpolate torch.backends.cuda.sdp_kernel = return_null_context + torch.UntypedStorage.is_cuda = is_cuda + + torch.nn.functional.interpolate = interpolate + torch.linalg.solve = linalg_solve + + torch.bmm = torch_bmm + torch.cat = torch_cat + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention