diff --git a/library/device_utils.py b/library/device_utils.py index 8823c5d9..d2e19745 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -2,6 +2,13 @@ import functools import gc import torch +try: + # intel gpu support for pytorch older than 2.5 + # ipex is not needed after pytorch 2.5 + import intel_extension_for_pytorch as ipex # noqa +except Exception: + pass + try: HAS_CUDA = torch.cuda.is_available() @@ -14,8 +21,6 @@ except Exception: HAS_MPS = False try: - import intel_extension_for_pytorch as ipex # noqa - HAS_XPU = torch.xpu.is_available() except Exception: HAS_XPU = False @@ -69,7 +74,7 @@ def init_ipex(): This function should run right after importing torch and before doing anything else. - If IPEX is not available, this function does nothing. + If xpu is not available, this function does nothing. """ try: if HAS_XPU: diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index e5aba693..a36664bb 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -2,7 +2,11 @@ import os import sys import contextlib import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +try: + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + legacy = True +except Exception: + legacy = False from .hijacks import ipex_hijacks # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: return True, "Skipping IPEX hijack" else: + try: # force xpu device on torch compile and triton + torch._inductor.utils.GPU_TYPES = ["xpu"] + torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu" + from triton import backends as triton_backends # pylint: disable=import-error + triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False + except Exception: + pass # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream @@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_current_stream_capturing = lambda: False torch.cuda.set_device = torch.xpu.set_device torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize torch.cuda.Event = torch.xpu.Event torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu torch.nn.Module.cuda = torch.nn.Module.xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback torch.cuda.Optional = torch.xpu.Optional torch.cuda.__cached__ = torch.xpu.__cached__ torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage torch.cuda.Any = torch.xpu.Any torch.cuda.__doc__ = torch.xpu.__doc__ torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor torch.cuda._get_device_index = torch.xpu._get_device_index torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor torch.cuda.List = torch.xpu.List torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage torch.cuda.random = torch.xpu.random torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.__name__ = torch.xpu.__name__ torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if legacy: + torch.cuda.os = torch.xpu.os + torch.cuda.Device = torch.xpu.Device + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.classproperty = torch.xpu.classproperty + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + if float(ipex.__version__[:3]) < 2.3: + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda._lazy_new = torch.xpu._lazy_new + + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + + if not legacy or float(ipex.__version__[:3]) >= 2.3: + torch.cuda._initialization_lock = torch.xpu._initialization_lock + torch.cuda._initialized = torch.xpu._initialized + torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu._queued_calls + torch.cuda._tls = torch.xpu._tls + torch.cuda.threading = torch.xpu.threading + torch.cuda.traceback = torch.xpu.traceback + # Memory: - torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None torch.cuda.empty_cache = torch.xpu.empty_cache + + if legacy: + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory = torch.xpu.memory torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot torch.cuda.memory_allocated = torch.xpu.memory_allocated torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated torch.cuda.memory_reserved = torch.xpu.memory_reserved @@ -128,32 +154,44 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.initial_seed = torch.xpu.initial_seed # AMP: - torch.cuda.amp = torch.xpu.amp - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + if legacy: + torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd + torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd + torch.cuda.amp = torch.xpu.amp + if float(ipex.__version__[:3]) < 2.3: + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count - ipex._C._DeviceProperties.major = 2024 - ipex._C._DeviceProperties.minor = 0 + if legacy and float(ipex.__version__[:3]) < 2.3: + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 12 + ipex._C._DeviceProperties.minor = 1 + else: + torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream + torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count + torch._C._XpuDeviceProperties.major = 12 + torch._C._XpuDeviceProperties.minor = 1 # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + # torch.xpu.mem_get_info always returns the total memory as free memory + torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch.cuda.mem_get_info = torch.xpu.mem_get_info torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True torch.cuda.has_half = True @@ -161,19 +199,19 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_fp16_supported = lambda *args, **kwargs: True torch.backends.cuda.is_built = lambda *args, **kwargs: True torch.version.cuda = "12.1" - torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"] + torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1) torch.cuda.get_device_properties.major = 12 torch.cuda.get_device_properties.minor = 1 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy) + try: + from .diffusers import ipex_diffusers + ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb) + except Exception: # pylint: disable=broad-exception-caught + pass torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 2bc62f65..400b59b6 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,177 +1,119 @@ import os import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from functools import cache +from functools import cache, wraps # pylint: disable=protected-access, missing-function-docstring, line-too-long # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers -sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5)) # Find something divisible with the input_tokens @cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size +def find_split_size(original_size, slice_block_size, slice_rate=2): + split_size = original_size + while True: + if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0: + return split_size + split_size = split_size - 1 + if split_size <= 1: + return 1 + return split_size + # Find slice sizes for SDPA @cache -def find_sdpa_slice_sizes(query_shape, query_element_size): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape +def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3): + batch_size, attn_heads, query_len, _ = query_shape + _, _, key_len, _ = key_shape - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size + slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three + split_batch_size = batch_size + split_head_size = attn_heads + split_query_size = query_len - do_split = False - do_split_2 = False - do_split_3 = False + do_batch_split = False + do_head_split = False + do_query_split = False - if block_size > sdpa_slice_trigger_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + if batch_size * slice_batch_size >= trigger_rate: + do_batch_split = True + split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate) - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + if split_batch_size * slice_batch_size > slice_rate: + slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 + do_head_split = True + split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate) -# Find slice sizes for BMM -@cache -def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): - batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] - slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size - block_size = batch_size_attention * slice_block_size + if split_head_size * slice_head_size > slice_rate: + slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024 + do_query_split = True + split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate) - split_slice_size = batch_size_attention - split_2_slice_size = input_tokens - split_3_slice_size = mat2_atten_shape + return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size - do_split = False - do_split_2 = False - do_split_3 = False - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - -original_torch_bmm = torch.bmm -def torch_bmm_32_bit(input, mat2, *, out=None): - if input.device.type != "xpu": - return original_torch_bmm(input, mat2, out=out) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) - - # Slice BMM - if do_split: - batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] - hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - out=out - ) - else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) - else: - hidden_states[start_idx:end_idx] = original_torch_bmm( - input[start_idx:end_idx], - mat2[start_idx:end_idx], - out=out - ) - torch.xpu.synchronize(input.device) - else: - return original_torch_bmm(input, mat2, out=out) - return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): +@wraps(torch.nn.functional.scaled_dot_product_attention) +def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): if query.device.type != "xpu": return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + is_unsqueezed = False + if len(query.shape) == 3: + query = query.unsqueeze(0) + is_unsqueezed = True + if len(key.shape) == 3: + key = key.unsqueeze(0) + if len(value.shape) == 3: + value = value.unsqueeze(0) + do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate) # Slice SDPA - if do_split: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + if do_batch_split: + batch_size, attn_heads, query_len, _ = query.shape + _, _, _, head_dim = value.shape + hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype) + if attn_mask is not None: + attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2])) + for ib in range(batch_size // split_batch_size): + start_idx = ib * split_batch_size + end_idx = (ib + 1) * split_batch_size + if do_head_split: + for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name + start_idx_h = ih * split_head_size + end_idx_h = (ih + 1) * split_head_size + if do_query_split: + for iq in range(query_len // split_query_size): # pylint: disable=invalid-name + start_idx_q = iq * split_query_size + end_idx_q = (iq + 1) * split_query_size + hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :], + key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2], - key[start_idx:end_idx, start_idx_2:end_idx_2], - value[start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, :, :, :], + key[start_idx:end_idx, :, :, :], + value[start_idx:end_idx, :, :, :], + attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) torch.xpu.synchronize(query.device) else: - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + if is_unsqueezed: + hidden_states.squeeze(0) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 732a1856..75715d16 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,312 +1,47 @@ -import os +from functools import wraps import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.24.0 # pylint: disable=import-error -from diffusers.models.attention_processor import Attention -from diffusers.utils import USE_PEFT_BACKEND -from functools import cache +import diffusers # pylint: disable=import-error # pylint: disable=protected-access, missing-function-docstring, line-too-long -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) -@cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size - -@cache -def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape - if slice_size is not None: - batch_size_attention = slice_size - - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three - - do_split = False - do_split_2 = False - do_split_3 = False - - if query_device_type != "xpu": - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - -class SlicedAttnProcessor: # pylint: disable=too-few-public-methods - r""" - Processor for implementing sliced attention. - - Args: - slice_size (`int`, *optional*): - The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and - `attention_head_dim` must be a multiple of the `slice_size`. - """ - - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, shape_three = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) - - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - #################################################################### - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None, - temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - args = () if USE_PEFT_BACKEND else (scale,) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) - - if do_split: - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - #################################################################### - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -def ipex_diffusers(): - #ARC GPUs can't allocate more than 4GB to a single block: - diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor - diffusers.models.attention_processor.AttnProcessor = AttnProcessor +# Diffusers FreeU +original_fourier_filter = diffusers.utils.torch_utils.fourier_filter +@wraps(diffusers.utils.torch_utils.fourier_filter) +def fourier_filter(x_in, threshold, scale): + return_dtype = x_in.dtype + return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype) + + +# fp64 error +class FluxPosEmbed(torch.nn.Module): + def __init__(self, theta: int, axes_dim): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + for i in range(n_axes): + cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=torch.float32, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False): + diffusers.utils.torch_utils.fourier_filter = fourier_filter + if not device_supports_fp64: + diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py index 6eb56bc2..0a861009 100644 --- a/library/ipex/gradscaler.py +++ b/library/ipex/gradscaler.py @@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un # pylint: disable=protected-access, missing-function-docstring, line-too-long -device_supports_fp64 = torch.xpu.has_fp64_dtype() +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index d3cef827..91569746 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -2,10 +2,19 @@ import os from functools import wraps from contextlib import nullcontext import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import numpy as np -device_supports_fp64 = torch.xpu.has_fp64_dtype() +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 +if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1: + try: + x = torch.ones((33000,33000), dtype=torch.float32, device="xpu") + del x + torch.xpu.empty_cache() + can_allocate_plus_4gb = True + except Exception: + can_allocate_plus_4gb = False +else: + can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1') # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -26,7 +35,7 @@ def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): - return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu" # Autocast @@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non original_interpolate = torch.nn.functional.interpolate @wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None or mode == 'bicubic': + if mode in {'bicubic', 'bilinear'}: return_device = tensor.device return_dtype = tensor.dtype return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, @@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None): return original_as_tensor(data, dtype=dtype, device=device) -if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: - original_torch_bmm = torch.bmm +if can_allocate_plus_4gb: original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: # 32 bit attention workarounds for Alchemist: try: - from .attention import torch_bmm_32_bit as original_torch_bmm - from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention except Exception: # pylint: disable=broad-exception-caught - original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention - -# Data Type Errors: -@wraps(torch.bmm) -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - return original_torch_bmm(input, mat2, out=out) - @wraps(torch.nn.functional.scaled_dot_product_attention) -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) if attn_mask is not None and query.dtype != attn_mask.dtype: attn_mask = attn_mask.to(dtype=query.dtype) - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + +# Data Type Errors: +original_torch_bmm = torch.bmm +@wraps(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +# Diffusers FreeU +original_fft_fftn = torch.fft.fftn +@wraps(torch.fft.fftn) +def fft_fftn(input, s=None, dim=None, norm=None, *, out=None): + return_dtype = input.dtype + return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype) + +# Diffusers FreeU +original_fft_ifftn = torch.fft.ifftn +@wraps(torch.fft.ifftn) +def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None): + return_dtype = input.dtype + return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm @@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None): bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_linear(input, weight, bias=bias) +original_functional_conv1d = torch.nn.functional.conv1d +@wraps(torch.nn.functional.conv1d) +def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + original_functional_conv2d = torch.nn.functional.conv2d @wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): @@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) -# A1111 Embedding BF16 -original_torch_cat = torch.cat -@wraps(torch.cat) -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) +# LTX Video +original_functional_conv3d = torch.nn.functional.conv3d +@wraps(torch.nn.functional.conv3d) +def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) # SwinIR BF16: original_functional_pad = torch.nn.functional.pad @@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor @wraps(torch.tensor) def torch_tensor(data, *args, dtype=None, device=None, **kwargs): + global device_supports_fp64 if check_device(device): device = return_xpu(device) if not device_supports_fp64: @@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs): original_torch_randn = torch.randn @wraps(torch.randn) def torch_randn(*args, device=None, dtype=None, **kwargs): - if dtype == bytes: + if dtype is bytes: dtype = None if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs): else: return original_torch_zeros(*args, device=device, **kwargs) +original_torch_full = torch.full +@wraps(torch.full) +def torch_full(*args, device=None, **kwargs): + if check_device(device): + return original_torch_full(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_full(*args, device=device, **kwargs) + original_torch_linspace = torch.linspace @wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): @@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs): else: return original_torch_linspace(*args, device=device, **kwargs) -original_torch_Generator = torch.Generator -@wraps(torch.Generator) -def torch_Generator(device=None): - if check_device(device): - return original_torch_Generator(return_xpu(device)) - else: - return original_torch_Generator(device) - original_torch_load = torch.load @wraps(torch.load) def torch_load(f, map_location=None, *args, **kwargs): @@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs): else: return original_torch_load(f, *args, map_location=map_location, **kwargs) +original_torch_Generator = torch.Generator +@wraps(torch.Generator) +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +@wraps(torch.cuda.synchronize) +def torch_cuda_synchronize(device=None): + if check_device(device): + return torch.xpu.synchronize(return_xpu(device)) + else: + return torch.xpu.synchronize(device) + # Hijack Functions: -def ipex_hijacks(): +def ipex_hijacks(legacy=True): + global device_supports_fp64, can_allocate_plus_4gb + if legacy and float(torch.__version__[:3]) < 2.5: + torch.nn.functional.interpolate = interpolate torch.tensor = torch_tensor torch.Tensor.to = Tensor_to torch.Tensor.cuda = Tensor_cuda @@ -289,9 +338,11 @@ def ipex_hijacks(): torch.randn = torch_randn torch.ones = torch_ones torch.zeros = torch_zeros + torch.full = torch_full torch.linspace = torch_linspace - torch.Generator = torch_Generator torch.load = torch_load + torch.Generator = torch_Generator + torch.cuda.synchronize = torch_cuda_synchronize torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel @@ -302,12 +353,15 @@ def ipex_hijacks(): torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.layer_norm = functional_layer_norm torch.nn.functional.linear = functional_linear + torch.nn.functional.conv1d = functional_conv1d torch.nn.functional.conv2d = functional_conv2d - torch.nn.functional.interpolate = interpolate + torch.nn.functional.conv3d = functional_conv3d torch.nn.functional.pad = functional_pad torch.bmm = torch_bmm - torch.cat = torch_cat + torch.fft.fftn = fft_fftn + torch.fft.ifftn = fft_ifftn if not device_supports_fp64: torch.from_numpy = from_numpy torch.as_tensor = as_tensor + return device_supports_fp64, can_allocate_plus_4gb