Update IPEX libs

This commit is contained in:
Disty0
2025-02-25 21:27:41 +03:00
parent 386b7332c6
commit f68702f71c
6 changed files with 330 additions and 556 deletions

View File

@@ -2,7 +2,11 @@ import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
legacy = True
except Exception:
legacy = False
from .hijacks import ipex_hijacks
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
try: # force xpu device on torch compile and triton
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
from triton import backends as triton_backends # pylint: disable=import-error
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
except Exception:
pass
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if legacy:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if float(ipex.__version__[:3]) < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
if not legacy or float(ipex.__version__[:3]) >= 2.3:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu._queued_calls
torch.cuda._tls = torch.xpu._tls
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
if legacy:
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory = torch.xpu.memory
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
@@ -128,32 +154,44 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if legacy:
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.xpu.amp
if float(ipex.__version__[:3]) < 2.3:
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0
if legacy and float(ipex.__version__[:3]) < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
ipex._C._DeviceProperties.minor = 1
else:
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
torch._C._XpuDeviceProperties.major = 12
torch._C._XpuDeviceProperties.minor = 1
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
# torch.xpu.mem_get_info always returns the total memory as free memory
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
@@ -161,19 +199,19 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
try:
from .diffusers import ipex_diffusers
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e

View File

@@ -1,177 +1,119 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
from functools import cache, wraps
# pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
def find_split_size(original_size, slice_block_size, slice_rate=2):
split_size = original_size
while True:
if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
return split_size
split_size = split_size - 1
if split_size <= 1:
return 1
return split_size
# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
batch_size, attn_heads, query_len, _ = query_shape
_, _, key_len, _ = key_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
split_batch_size = batch_size
split_head_size = attn_heads
split_query_size = query_len
do_split = False
do_split_2 = False
do_split_3 = False
do_batch_split = False
do_head_split = False
do_query_split = False
if block_size > sdpa_slice_trigger_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
if batch_size * slice_batch_size >= trigger_rate:
do_batch_split = True
split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if split_batch_size * slice_batch_size > slice_rate:
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
do_head_split = True
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
if split_head_size * slice_head_size > slice_rate:
slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
do_query_split = True
split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
split_3_slice_size = mat2_atten_shape
return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
@wraps(torch.nn.functional.scaled_dot_product_attention)
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
is_unsqueezed = False
if len(query.shape) == 3:
query = query.unsqueeze(0)
is_unsqueezed = True
if len(key.shape) == 3:
key = key.unsqueeze(0)
if len(value.shape) == 3:
value = value.unsqueeze(0)
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
# Slice SDPA
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
if do_batch_split:
batch_size, attn_heads, query_len, _ = query.shape
_, _, _, head_dim = value.shape
hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
if attn_mask is not None:
attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
for ib in range(batch_size // split_batch_size):
start_idx = ib * split_batch_size
end_idx = (ib + 1) * split_batch_size
if do_head_split:
for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
start_idx_h = ih * split_head_size
end_idx_h = (ih + 1) * split_head_size
if do_query_split:
for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
start_idx_q = iq * split_query_size
end_idx_q = (iq + 1) * split_query_size
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, :, :, :],
key[start_idx:end_idx, :, :, :],
value[start_idx:end_idx, :, :, :],
attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
if is_unsqueezed:
hidden_states.squeeze(0)
return hidden_states

View File

@@ -1,312 +1,47 @@
import os
from functools import wraps
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
import diffusers # pylint: disable=import-error
# pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, shape_three = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
# Diffusers FreeU
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
@wraps(diffusers.utils.torch_utils.fourier_filter)
def fourier_filter(x_in, threshold, scale):
return_dtype = x_in.dtype
return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
# fp64 error
class FluxPosEmbed(torch.nn.Module):
def __init__(self, theta: int, axes_dim):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
for i in range(n_axes):
cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=torch.float32,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
diffusers.utils.torch_utils.fourier_filter = fourier_filter
if not device_supports_fp64:
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed

View File

@@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state

View File

@@ -2,10 +2,19 @@ import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
del x
torch.xpu.empty_cache()
can_allocate_plus_4gb = True
except Exception:
can_allocate_plus_4gb = False
else:
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -26,7 +35,7 @@ def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
# Autocast
@@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None or mode == 'bicubic':
if mode in {'bicubic', 'bilinear'}:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None):
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm
if can_allocate_plus_4gb:
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
# Data Type Errors:
original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
# Diffusers FreeU
original_fft_fftn = torch.fft.fftn
@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
return_dtype = input.dtype
return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
# Diffusers FreeU
original_fft_ifftn = torch.fft.ifftn
@wraps(torch.fft.ifftn)
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
return_dtype = input.dtype
return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None):
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
original_functional_conv1d = torch.nn.functional.conv1d
@wraps(torch.nn.functional.conv1d)
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
@@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# LTX Video
original_functional_conv3d = torch.nn.functional.conv3d
@wraps(torch.nn.functional.conv3d)
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None):
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
global device_supports_fp64
if check_device(device):
device = return_xpu(device)
if not device_supports_fp64:
@@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
if dtype is bytes:
dtype = None
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs):
else:
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
if check_device(device):
return original_torch_full(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_full(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
@@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs):
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
@@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs):
else:
return original_torch_load(f, *args, map_location=map_location, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
if check_device(device):
return torch.xpu.synchronize(return_xpu(device))
else:
return torch.xpu.synchronize(device)
# Hijack Functions:
def ipex_hijacks():
def ipex_hijacks(legacy=True):
global device_supports_fp64, can_allocate_plus_4gb
if legacy and float(torch.__version__[:3]) < 2.5:
torch.nn.functional.interpolate = interpolate
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
@@ -289,9 +338,11 @@ def ipex_hijacks():
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.full = torch_full
torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load
torch.Generator = torch_Generator
torch.cuda.synchronize = torch_cuda_synchronize
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
@@ -302,12 +353,15 @@ def ipex_hijacks():
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv1d = functional_conv1d
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.nn.functional.conv3d = functional_conv3d
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat
torch.fft.fftn = fft_fftn
torch.fft.ifftn = fft_ifftn
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor
return device_supports_fp64, can_allocate_plus_4gb