From 15d5e78ac2c865ce45d49ffb89e720721adb0fe7 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 1 Jan 2024 12:44:17 +0300 Subject: [PATCH] Update IPEX Libs --- library/ipex/__init__.py | 1 + library/ipex/attention.py | 188 ++++++++++++--------- library/ipex/diffusers.py | 241 ++++++++++++++++++++++++--- library/ipex/hijacks.py | 340 +++++++++++++++++++------------------- 4 files changed, 489 insertions(+), 281 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index c7854791..33350493 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -140,6 +140,7 @@ def ipex_init(): # pylint: disable=too-many-statements # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 2e61f2c9..e98807a8 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,41 +1,98 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -original_torch_bmm = torch.bmm -def torch_bmm_32_bit(input, mat2, *, out=None): - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = input.element_size() - slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply +# ARC GPUs can't allocate more than 4GB to a single block, so we slice the attention layers + +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find a slice size that keeps the block under attention_slice_rate +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size block_size = batch_size_attention * slice_block_size split_slice_size = batch_size_attention - if block_size > 4: - do_split = True - # Find something divisible with the input_tokens - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - split_2_slice_size = input_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply - do_split_2 = True - # Find something divisible with the input_tokens - while (split_2_slice_size * slice_block_size_2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - else: - do_split = False + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: +
do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = input_tokens + split_3_slice_size = mat2_atten_shape + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + if input.device.type != "xpu": + return original_torch_bmm(input, mat2, out=out) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM if do_split: + batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -44,11 +101,21 @@ def torch_bmm_32_bit(input, mat2, *, out=None): for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) + if do_split_3: + for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + mat2[start_idx:end_idx, start_idx_2:end_idx_2, 
start_idx_3:end_idx_3], + out=out + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) else: hidden_states[start_idx:end_idx] = original_torch_bmm( input[start_idx:end_idx], @@ -61,54 +128,13 @@ def torch_bmm_32_bit(input, mat2, *, out=None): original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - if len(query.shape) == 3: - batch_size_attention, query_tokens, shape_three = query.shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query.shape - - block_multiply = query.element_size() - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - if block_size > 4: - do_split = True - # Find something divisible with the batch_size_attention - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - split_2_slice_size = query_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply - do_split_2 = True - # Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size_2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - split_3_slice_size = shape_three - if split_2_slice_size * slice_block_size_2 > 4: - slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply - do_split_3 = True - # Find something divisible with the shape_three - while (split_3_slice_size * slice_block_size_3) > 4: - split_3_slice_size = split_3_slice_size // 2 - if split_3_slice_size <= 1: - split_3_slice_size = 1 - break - else: - do_split_3 = False - else: - do_split_2 = False - else: - do_split = False + if query.device.type != "xpu": + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + # Slice SDPA if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -145,7 +171,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo dropout_p=dropout_p, is_causal=is_causal ) else: - return original_scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal - ) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index c32af507..617c1236 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,10 +1,59 @@ +import os import torch import 
intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +@cache +def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + if slice_size is not None: + batch_size_attention = slice_size + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. 
@@ -18,7 +67,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + residual = hidden_states input_ndim = hidden_states.ndim @@ -54,49 +105,61 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - block_multiply = query.element_size() - slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply - block_size = query_tokens * slice_block_size - split_2_slice_size = query_tokens - if block_size > 4: - do_split_2 = True - #Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False - - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), slice_size=self.slice_size) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size if do_split_2: for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = 
key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - + del query_slice + del key_slice + del attn_mask_slice attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + #################################################################### hidden_states = attn.batch_to_head_dim(hidden_states) @@ -115,6 +178,130 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods return hidden_states + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None, + temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size()) + + if do_split: + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // 
split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + #################################################################### + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + def ipex_diffusers(): #ARC GPUs can't allocate more than 4GB to a single block: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index eb5f779f..93fd7537 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,67 +1,9 @@ import contextlib -import importlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return -class CondFunc: # pylint: disable=missing-class-docstring - def __new__(cls, orig_func, sub_func, cond_func): - self = super(CondFunc, cls).__new__(cls) - if isinstance(orig_func, str): - func_path = orig_func.split('.') - for i in range(len(func_path)-1, -1, -1): - try: - resolved_obj = importlib.import_module('.'.join(func_path[:i])) - break - except ImportError: - pass - for attr_name in 
func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) - self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) - -_utils = torch.utils.data._utils -def _shutdown_workers(self): - if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: - return - if hasattr(self, "_shutdown") and not self._shutdown: - self._shutdown = True - try: - if hasattr(self, '_pin_memory_thread'): - self._pin_memory_thread_done_event.set() - self._worker_result_queue.put((None, None)) - self._pin_memory_thread.join() - self._worker_result_queue.cancel_join_thread() - self._worker_result_queue.close() - self._workers_done_event.set() - for worker_id in range(len(self._workers)): - if self._persistent_workers or self._workers_status[worker_id]: - self._mark_worker_as_unavailable(worker_id, shutdown=True) - for w in self._workers: # pylint: disable=invalid-name - w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) - for q in self._index_queues: # pylint: disable=invalid-name - q.cancel_join_thread() - q.close() - finally: - if self._worker_pids_set: - torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) - self._worker_pids_set = False - for w in self._workers: # pylint: disable=invalid-name - if w.is_alive(): - w.terminate() - class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: @@ -71,17 +13,18 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr def return_null_context(*args, **kwargs): # pylint: disable=unused-argument return contextlib.nullcontext() +@property +def is_cuda(self): + return self.device.type == 'xpu' or self.device.type == 'cuda' + def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" -def ipex_no_cuda(orig_func, *args, **kwargs): - torch.cuda.is_available = lambda: False - orig_func(*args, **kwargs) - torch.cuda.is_available = torch.xpu.is_available +# Autocast original_autocast = torch.autocast def ipex_autocast(*args, **kwargs): if len(args) > 0 and args[0] == "cuda": @@ -89,15 +32,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) -# Embedding BF16 -original_torch_cat = torch.cat -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), 
tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) - -# Latent antialias: +# Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -109,19 +44,19 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) -original_linalg_solve = torch.linalg.solve -def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name - if A.device != torch.device("cpu") or B.device != torch.device("cpu"): - return_device = A.device - return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) +# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): +original_from_numpy = torch.from_numpy +def from_numpy(ndarray): + if ndarray.dtype == float: + return original_from_numpy(ndarray.astype('float32')) else: - return original_linalg_solve(A, B, *args, **kwargs) + return original_from_numpy(ndarray) if torch.xpu.has_fp64_dtype(): original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: - # 64 bit attention workarounds for Alchemist: + # 32 bit attention workarounds for Alchemist: try: from .attention import torch_bmm_32_bit as original_torch_bmm from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention @@ -129,7 +64,8 @@ else: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -# dtype errors: + +# Data Type Errors: def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) @@ -142,111 +78,171 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. 
value = value.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) -@property -def is_cuda(self): - return self.device.type == 'xpu' +# A1111 FP16 +original_functional_group_norm = torch.nn.functional.group_norm +def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps) -def ipex_hijacks(): - CondFunc('torch.tensor', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.Tensor.to', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.Tensor.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.UntypedStorage.__init__', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.UntypedStorage.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.empty', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.randn', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.ones', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.zeros', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.linspace', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: - orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), - lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) - if hasattr(torch.xpu, "Generator"): - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") +# A1111 BF16 
+original_functional_layer_norm = torch.nn.functional.layer_norm +def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps) + +# Training +original_functional_linear = torch.nn.functional.linear +def functional_linear(input, weight, bias=None): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_linear(input, weight, bias=bias) + +original_functional_conv2d = torch.nn.functional.conv2d +def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +# A1111 Embedding BF16 +original_torch_cat = torch.cat +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) else: - CondFunc('torch.Generator', - lambda orig_func, device=None: orig_func(return_xpu(device)), - lambda orig_func, device=None: check_device(device)) + return original_torch_cat(tensor, *args, **kwargs) - # TiledVAE and ControlNet: - CondFunc('torch.batch_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - CondFunc('torch.instance_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) +# SwinIR BF16: +original_funtional_pad = torch.nn.functional.pad +def funtional_pad(input, pad, mode='constant', value=None): + if mode == 'reflect' and input.dtype == torch.bfloat16: + return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) + else: + return original_funtional_pad(input, pad, mode=mode, value=value) - # Functions with dtype errors: - CondFunc('torch.nn.modules.GroupNorm.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - # Training: - CondFunc('torch.nn.modules.linear.Linear.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.conv.Conv2d.forward', - 
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - # BF16: - CondFunc('torch.nn.functional.layer_norm', - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - weight is not None and input.dtype != weight.data.dtype) - # SwinIR BF16: - CondFunc('torch.nn.functional.pad', - lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), - lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) - # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): - if not torch.xpu.has_fp64_dtype(): - CondFunc('torch.from_numpy', - lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), - lambda orig_func, ndarray: ndarray.dtype == float) +original_torch_tensor = torch.tensor +def torch_tensor(*args, device=None, **kwargs): + if check_device(device): + return original_torch_tensor(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_tensor(*args, device=device, **kwargs) - # Broken functions when torch.cuda.is_available is True: - # Pin Memory: - CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', - lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), - lambda orig_func, *args, **kwargs: True) +original_Tensor_to = torch.Tensor.to +def Tensor_to(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_to(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_to(self, device, *args, **kwargs) - # Functions that make compile mad with CondFunc: - torch.nn.DataParallel = DummyDataParallel - torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers +original_Tensor_cuda = torch.Tensor.cuda +def Tensor_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_cuda(self, device, *args, **kwargs) + +original_UntypedStorage_init = torch.UntypedStorage.__init__ +def UntypedStorage_init(*args, device=None, **kwargs): + if check_device(device): + return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) + else: + return original_UntypedStorage_init(*args, device=device, **kwargs) + +original_UntypedStorage_cuda = torch.UntypedStorage.cuda +def UntypedStorage_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_UntypedStorage_cuda(self, device, *args, **kwargs) + +original_torch_empty = torch.empty +def torch_empty(*args, device=None, **kwargs): + if check_device(device): + return original_torch_empty(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_empty(*args, device=device, **kwargs) + +original_torch_randn = torch.randn +def torch_randn(*args, device=None, **kwargs): + if check_device(device): + return original_torch_randn(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_randn(*args, device=device, **kwargs) + +original_torch_ones = torch.ones +def torch_ones(*args, device=None, **kwargs): + if 
check_device(device): + return original_torch_ones(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_ones(*args, device=device, **kwargs) + +original_torch_zeros = torch.zeros +def torch_zeros(*args, device=None, **kwargs): + if check_device(device): + return original_torch_zeros(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_zeros(*args, device=device, **kwargs) + +original_torch_linspace = torch.linspace +def torch_linspace(*args, device=None, **kwargs): + if check_device(device): + return original_torch_linspace(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_linspace(*args, device=device, **kwargs) + +original_torch_Generator = torch.Generator +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +original_torch_load = torch.load +def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): + if check_device(map_location): + return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + else: + return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + +# Hijack Functions: +def ipex_hijacks(): + torch.tensor = torch_tensor + torch.Tensor.to = Tensor_to + torch.Tensor.cuda = Tensor_cuda + torch.UntypedStorage.__init__ = UntypedStorage_init + torch.UntypedStorage.cuda = UntypedStorage_cuda + torch.empty = torch_empty + torch.randn = torch_randn + torch.ones = torch_ones + torch.zeros = torch_zeros + torch.linspace = torch_linspace + torch.Generator = torch_Generator + torch.load = torch_load - torch.autocast = ipex_autocast torch.backends.cuda.sdp_kernel = return_null_context + torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda + torch.autocast = ipex_autocast + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention + torch.nn.functional.group_norm = functional_group_norm + torch.nn.functional.layer_norm = functional_layer_norm + torch.nn.functional.linear = functional_linear + torch.nn.functional.conv2d = functional_conv2d torch.nn.functional.interpolate = interpolate - torch.linalg.solve = linalg_solve + torch.nn.functional.pad = funtional_pad torch.bmm = torch_bmm torch.cat = torch_cat - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention + if not torch.xpu.has_fp64_dtype(): + torch.from_numpy = from_numpy
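
For reference, the slice-size heuristic introduced in library/ipex/attention.py above can be exercised without torch or an XPU device. The short standalone sketch below mirrors the find_slice_size() halving loop and the first stage of find_sdpa_slice_sizes(); the sdpa_batch_slice() helper name and the example query shape are illustrative assumptions, not part of the patch.

import os

# Threshold used by the patch: per-slice tensor size computed in MB
# (elements * element_size / 1024 / 1024), compared against
# IPEX_ATTENTION_SLICE_RATE (default 4).
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

def find_slice_size(slice_size, slice_block_size):
    # Same halving loop as the patched attention.py: shrink the slice until
    # slice_size * slice_block_size drops under attention_slice_rate.
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size

def sdpa_batch_slice(query_shape, query_element_size):
    # First stage of find_sdpa_slice_sizes(): how far the batch/head dimension is split.
    # The patched code then repeats the same halving on the token and third dimensions
    # if the per-slice size is still above the rate.
    if len(query_shape) == 3:
        batch_size_attention, query_tokens, shape_three = query_shape
        shape_four = 1
    else:
        batch_size_attention, query_tokens, shape_three, shape_four = query_shape
    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
    return find_slice_size(batch_size_attention, slice_block_size), slice_block_size

if __name__ == "__main__":
    # Hypothetical fp16 query (element_size=2) of shape (16, 4096, 40, 64); the shape is
    # an illustrative assumption, not a value taken from the patch.
    split, mb_each = sdpa_batch_slice((16, 4096, 40, 64), 2)
    print(f"{mb_each:.0f} MB per batch element -> batch dim sliced to {split}")
    # Prints: 20 MB per batch element -> batch dim sliced to 1

With the default rate of 4, this example is sliced down to single batch elements before torch_bmm_32_bit()/scaled_dot_product_attention_32_bit() run the original kernels on each sub-block; both wrappers fall back to the unsliced originals when the tensors are not on an XPU device.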