Update IPEX hijacks

This commit is contained in:
Disty0
2023-12-05 14:17:31 +03:00
parent 46cf41cc93
commit bce9a081db
3 changed files with 34 additions and 7 deletions

View File

@@ -74,6 +74,11 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size