IPEX attention optimizations

This commit is contained in:
Disty0
2023-09-28 14:02:25 +03:00
parent 1e395ed285
commit 209eafb631
3 changed files with 25 additions and 20 deletions

View File

@@ -55,13 +55,14 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
block_size = (batch_size_attention * query_tokens * shape_three) / 1024 * block_multiply #MB
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size >= 4000:
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while ((self.slice_size * split_2_slice_size * shape_three) / 1024 * block_multiply) > 4000:
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1