Disable Diffusers slicing if device is not XPU

This commit is contained in:
Disty0
2024-01-02 11:50:08 +03:00
parent 479bac447e
commit 49148eb36e

View File

@@ -20,7 +20,7 @@ def find_slice_size(slice_size, slice_block_size):
return slice_size return slice_size
@cache @cache
def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None): def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3: if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1 shape_four = 1
@@ -40,6 +40,9 @@ def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None)
do_split_2 = False do_split_2 = False
do_split_3 = False do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate: if block_size > attention_slice_rate:
do_split = True do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size) split_slice_size = find_slice_size(split_slice_size, slice_block_size)
@@ -107,7 +110,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
#################################################################### ####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it: # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), slice_size=self.slice_size) _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size): for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size start_idx = i * split_slice_size
@@ -227,7 +230,7 @@ class AttnProcessor:
# ARC GPUs can't allocate more than 4GB to a single block, Slice it: # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size()) do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split: if do_split:
for i in range(batch_size_attention // split_slice_size): for i in range(batch_size_attention // split_slice_size):