diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 617c1236..56528551 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -20,7 +20,7 @@ def find_slice_size(slice_size, slice_block_size): return slice_size @cache -def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None): +def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): if len(query_shape) == 3: batch_size_attention, query_tokens, shape_three = query_shape shape_four = 1 @@ -40,6 +40,9 @@ def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None) do_split_2 = False do_split_3 = False + if query_device_type != "xpu" + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + if block_size > attention_slice_rate: do_split = True split_slice_size = find_slice_size(split_slice_size, slice_block_size) @@ -107,7 +110,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods #################################################################### # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), slice_size=self.slice_size) + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -227,7 +230,7 @@ class AttnProcessor: # ARC GPUs can't allocate more than 4GB to a single block, Slice it: batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size()) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) if do_split: for i in range(batch_size_attention // split_slice_size):