mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update IPEX Libs
This commit is contained in:
@@ -140,6 +140,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
|
|
||||||
# C
|
# C
|
||||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||||
|
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||||
ipex._C._DeviceProperties.major = 2023
|
ipex._C._DeviceProperties.major = 2023
|
||||||
ipex._C._DeviceProperties.minor = 2
|
ipex._C._DeviceProperties.minor = 2
|
||||||
|
|
||||||
|
|||||||
@@ -1,41 +1,98 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
original_torch_bmm = torch.bmm
|
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
|
||||||
def torch_bmm_32_bit(input, mat2, *, out=None):
|
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||||
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||||
block_multiply = input.element_size()
|
|
||||||
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
|
# Find something divisible with the input_tokens
|
||||||
|
@cache
|
||||||
|
def find_slice_size(slice_size, slice_block_size):
|
||||||
|
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||||
|
slice_size = slice_size // 2
|
||||||
|
if slice_size <= 1:
|
||||||
|
slice_size = 1
|
||||||
|
break
|
||||||
|
return slice_size
|
||||||
|
|
||||||
|
# Find slice sizes for SDPA
|
||||||
|
@cache
|
||||||
|
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
||||||
|
if len(query_shape) == 3:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query_shape
|
||||||
|
shape_four = 1
|
||||||
|
else:
|
||||||
|
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||||
|
|
||||||
|
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
block_size = batch_size_attention * slice_block_size
|
block_size = batch_size_attention * slice_block_size
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
split_slice_size = batch_size_attention
|
||||||
if block_size > 4:
|
split_2_slice_size = query_tokens
|
||||||
do_split = True
|
split_3_slice_size = shape_three
|
||||||
# Find something divisible with the input_tokens
|
|
||||||
while (split_slice_size * slice_block_size) > 4:
|
|
||||||
split_slice_size = split_slice_size // 2
|
|
||||||
if split_slice_size <= 1:
|
|
||||||
split_slice_size = 1
|
|
||||||
break
|
|
||||||
split_2_slice_size = input_tokens
|
|
||||||
if split_slice_size * slice_block_size > 4:
|
|
||||||
slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
|
|
||||||
do_split_2 = True
|
|
||||||
# Find something divisible with the input_tokens
|
|
||||||
while (split_2_slice_size * slice_block_size_2) > 4:
|
|
||||||
split_2_slice_size = split_2_slice_size // 2
|
|
||||||
if split_2_slice_size <= 1:
|
|
||||||
split_2_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split_2 = False
|
|
||||||
else:
|
|
||||||
do_split = False
|
|
||||||
|
|
||||||
|
do_split = False
|
||||||
|
do_split_2 = False
|
||||||
|
do_split_3 = False
|
||||||
|
|
||||||
|
if block_size > sdpa_slice_trigger_rate:
|
||||||
|
do_split = True
|
||||||
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
|
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||||
|
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
do_split_2 = True
|
||||||
|
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||||
|
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||||
|
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
do_split_3 = True
|
||||||
|
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||||
|
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
|
|
||||||
|
# Find slice sizes for BMM
|
||||||
|
@cache
|
||||||
|
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
||||||
|
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
||||||
|
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||||
|
block_size = batch_size_attention * slice_block_size
|
||||||
|
|
||||||
|
split_slice_size = batch_size_attention
|
||||||
|
split_2_slice_size = input_tokens
|
||||||
|
split_3_slice_size = mat2_atten_shape
|
||||||
|
|
||||||
|
do_split = False
|
||||||
|
do_split_2 = False
|
||||||
|
do_split_3 = False
|
||||||
|
|
||||||
|
if block_size > attention_slice_rate:
|
||||||
|
do_split = True
|
||||||
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
|
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||||
|
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||||
|
do_split_2 = True
|
||||||
|
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||||
|
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||||
|
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
||||||
|
do_split_3 = True
|
||||||
|
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||||
|
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
|
|
||||||
|
|
||||||
|
original_torch_bmm = torch.bmm
|
||||||
|
def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||||
|
if input.device.type != "xpu":
|
||||||
|
return original_torch_bmm(input, mat2, out=out)
|
||||||
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
||||||
|
|
||||||
|
# Slice BMM
|
||||||
if do_split:
|
if do_split:
|
||||||
|
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
start_idx = i * split_slice_size
|
start_idx = i * split_slice_size
|
||||||
@@ -44,6 +101,16 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
|
|||||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
|
if do_split_3:
|
||||||
|
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
||||||
|
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
out=out
|
||||||
|
)
|
||||||
|
else:
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
@@ -61,54 +128,13 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
|
|||||||
|
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
if query.device.type != "xpu":
|
||||||
if len(query.shape) == 3:
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||||
batch_size_attention, query_tokens, shape_three = query.shape
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||||
shape_four = 1
|
|
||||||
else:
|
|
||||||
batch_size_attention, query_tokens, shape_three, shape_four = query.shape
|
|
||||||
|
|
||||||
block_multiply = query.element_size()
|
|
||||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
|
|
||||||
block_size = batch_size_attention * slice_block_size
|
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
|
||||||
if block_size > 4:
|
|
||||||
do_split = True
|
|
||||||
# Find something divisible with the batch_size_attention
|
|
||||||
while (split_slice_size * slice_block_size) > 4:
|
|
||||||
split_slice_size = split_slice_size // 2
|
|
||||||
if split_slice_size <= 1:
|
|
||||||
split_slice_size = 1
|
|
||||||
break
|
|
||||||
split_2_slice_size = query_tokens
|
|
||||||
if split_slice_size * slice_block_size > 4:
|
|
||||||
slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
|
|
||||||
do_split_2 = True
|
|
||||||
# Find something divisible with the query_tokens
|
|
||||||
while (split_2_slice_size * slice_block_size_2) > 4:
|
|
||||||
split_2_slice_size = split_2_slice_size // 2
|
|
||||||
if split_2_slice_size <= 1:
|
|
||||||
split_2_slice_size = 1
|
|
||||||
break
|
|
||||||
split_3_slice_size = shape_three
|
|
||||||
if split_2_slice_size * slice_block_size_2 > 4:
|
|
||||||
slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
|
|
||||||
do_split_3 = True
|
|
||||||
# Find something divisible with the shape_three
|
|
||||||
while (split_3_slice_size * slice_block_size_3) > 4:
|
|
||||||
split_3_slice_size = split_3_slice_size // 2
|
|
||||||
if split_3_slice_size <= 1:
|
|
||||||
split_3_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split_3 = False
|
|
||||||
else:
|
|
||||||
do_split_2 = False
|
|
||||||
else:
|
|
||||||
do_split = False
|
|
||||||
|
|
||||||
|
# Slice SDPA
|
||||||
if do_split:
|
if do_split:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
start_idx = i * split_slice_size
|
start_idx = i * split_slice_size
|
||||||
@@ -145,7 +171,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
|||||||
dropout_p=dropout_p, is_causal=is_causal
|
dropout_p=dropout_p, is_causal=is_causal
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return original_scaled_dot_product_attention(
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||||
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
|
||||||
)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -1,10 +1,59 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
import diffusers #0.24.0 # pylint: disable=import-error
|
import diffusers #0.24.0 # pylint: disable=import-error
|
||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
|
from diffusers.utils import USE_PEFT_BACKEND
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
|
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_slice_size(slice_size, slice_block_size):
|
||||||
|
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||||
|
slice_size = slice_size // 2
|
||||||
|
if slice_size <= 1:
|
||||||
|
slice_size = 1
|
||||||
|
break
|
||||||
|
return slice_size
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None):
|
||||||
|
if len(query_shape) == 3:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query_shape
|
||||||
|
shape_four = 1
|
||||||
|
else:
|
||||||
|
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||||
|
if slice_size is not None:
|
||||||
|
batch_size_attention = slice_size
|
||||||
|
|
||||||
|
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
block_size = batch_size_attention * slice_block_size
|
||||||
|
|
||||||
|
split_slice_size = batch_size_attention
|
||||||
|
split_2_slice_size = query_tokens
|
||||||
|
split_3_slice_size = shape_three
|
||||||
|
|
||||||
|
do_split = False
|
||||||
|
do_split_2 = False
|
||||||
|
do_split_3 = False
|
||||||
|
|
||||||
|
if block_size > attention_slice_rate:
|
||||||
|
do_split = True
|
||||||
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
|
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||||
|
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
do_split_2 = True
|
||||||
|
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||||
|
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||||
|
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
do_split_3 = True
|
||||||
|
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||||
|
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
|
|
||||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||||
r"""
|
r"""
|
||||||
Processor for implementing sliced attention.
|
Processor for implementing sliced attention.
|
||||||
@@ -18,7 +67,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
def __init__(self, slice_size):
|
def __init__(self, slice_size):
|
||||||
self.slice_size = slice_size
|
self.slice_size = slice_size
|
||||||
|
|
||||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
input_ndim = hidden_states.ndim
|
||||||
@@ -54,49 +105,61 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
|
####################################################################
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||||
block_multiply = query.element_size()
|
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), slice_size=self.slice_size)
|
||||||
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
|
|
||||||
block_size = query_tokens * slice_block_size
|
|
||||||
split_2_slice_size = query_tokens
|
|
||||||
if block_size > 4:
|
|
||||||
do_split_2 = True
|
|
||||||
#Find something divisible with the query_tokens
|
|
||||||
while (split_2_slice_size * slice_block_size) > 4:
|
|
||||||
split_2_slice_size = split_2_slice_size // 2
|
|
||||||
if split_2_slice_size <= 1:
|
|
||||||
split_2_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split_2 = False
|
|
||||||
|
|
||||||
for i in range(batch_size_attention // self.slice_size):
|
|
||||||
start_idx = i * self.slice_size
|
|
||||||
end_idx = (i + 1) * self.slice_size
|
|
||||||
|
|
||||||
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
|
start_idx = i * split_slice_size
|
||||||
|
end_idx = (i + 1) * split_slice_size
|
||||||
if do_split_2:
|
if do_split_2:
|
||||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
|
if do_split_3:
|
||||||
|
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
|
|
||||||
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||||
|
del attn_slice
|
||||||
else:
|
else:
|
||||||
query_slice = query[start_idx:end_idx]
|
query_slice = query[start_idx:end_idx]
|
||||||
key_slice = key[start_idx:end_idx]
|
key_slice = key[start_idx:end_idx]
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx] = attn_slice
|
hidden_states[start_idx:end_idx] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
####################################################################
|
||||||
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
@@ -115,6 +178,130 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessor:
|
||||||
|
r"""
|
||||||
|
Default processor for performing attention-related computations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states=None, attention_mask=None,
|
||||||
|
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
args = () if USE_PEFT_BACKEND else (scale,)
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states, *args)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states, *args)
|
||||||
|
value = attn.to_v(encoder_hidden_states, *args)
|
||||||
|
|
||||||
|
query = attn.head_to_batch_dim(query)
|
||||||
|
key = attn.head_to_batch_dim(key)
|
||||||
|
value = attn.head_to_batch_dim(value)
|
||||||
|
|
||||||
|
####################################################################
|
||||||
|
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||||
|
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||||
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size())
|
||||||
|
|
||||||
|
if do_split:
|
||||||
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
|
start_idx = i * split_slice_size
|
||||||
|
end_idx = (i + 1) * split_slice_size
|
||||||
|
if do_split_2:
|
||||||
|
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
|
if do_split_3:
|
||||||
|
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
|
|
||||||
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
query_slice = query[start_idx:end_idx]
|
||||||
|
key_slice = key[start_idx:end_idx]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||||
|
hidden_states = torch.bmm(attention_probs, value)
|
||||||
|
####################################################################
|
||||||
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def ipex_diffusers():
|
def ipex_diffusers():
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
||||||
|
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
||||||
|
|||||||
@@ -1,67 +1,9 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import importlib
|
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||||
|
|
||||||
class CondFunc: # pylint: disable=missing-class-docstring
|
|
||||||
def __new__(cls, orig_func, sub_func, cond_func):
|
|
||||||
self = super(CondFunc, cls).__new__(cls)
|
|
||||||
if isinstance(orig_func, str):
|
|
||||||
func_path = orig_func.split('.')
|
|
||||||
for i in range(len(func_path)-1, -1, -1):
|
|
||||||
try:
|
|
||||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
|
||||||
break
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
for attr_name in func_path[i:-1]:
|
|
||||||
resolved_obj = getattr(resolved_obj, attr_name)
|
|
||||||
orig_func = getattr(resolved_obj, func_path[-1])
|
|
||||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
|
||||||
self.__init__(orig_func, sub_func, cond_func)
|
|
||||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
|
||||||
def __init__(self, orig_func, sub_func, cond_func):
|
|
||||||
self.__orig_func = orig_func
|
|
||||||
self.__sub_func = sub_func
|
|
||||||
self.__cond_func = cond_func
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
|
||||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.__orig_func(*args, **kwargs)
|
|
||||||
|
|
||||||
_utils = torch.utils.data._utils
|
|
||||||
def _shutdown_workers(self):
|
|
||||||
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
|
|
||||||
return
|
|
||||||
if hasattr(self, "_shutdown") and not self._shutdown:
|
|
||||||
self._shutdown = True
|
|
||||||
try:
|
|
||||||
if hasattr(self, '_pin_memory_thread'):
|
|
||||||
self._pin_memory_thread_done_event.set()
|
|
||||||
self._worker_result_queue.put((None, None))
|
|
||||||
self._pin_memory_thread.join()
|
|
||||||
self._worker_result_queue.cancel_join_thread()
|
|
||||||
self._worker_result_queue.close()
|
|
||||||
self._workers_done_event.set()
|
|
||||||
for worker_id in range(len(self._workers)):
|
|
||||||
if self._persistent_workers or self._workers_status[worker_id]:
|
|
||||||
self._mark_worker_as_unavailable(worker_id, shutdown=True)
|
|
||||||
for w in self._workers: # pylint: disable=invalid-name
|
|
||||||
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
|
|
||||||
for q in self._index_queues: # pylint: disable=invalid-name
|
|
||||||
q.cancel_join_thread()
|
|
||||||
q.close()
|
|
||||||
finally:
|
|
||||||
if self._worker_pids_set:
|
|
||||||
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
|
|
||||||
self._worker_pids_set = False
|
|
||||||
for w in self._workers: # pylint: disable=invalid-name
|
|
||||||
if w.is_alive():
|
|
||||||
w.terminate()
|
|
||||||
|
|
||||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||||
@@ -71,17 +13,18 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
|||||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_cuda(self):
|
||||||
|
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
||||||
|
|
||||||
def check_device(device):
|
def check_device(device):
|
||||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||||
|
|
||||||
def return_xpu(device):
|
def return_xpu(device):
|
||||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||||
|
|
||||||
def ipex_no_cuda(orig_func, *args, **kwargs):
|
|
||||||
torch.cuda.is_available = lambda: False
|
|
||||||
orig_func(*args, **kwargs)
|
|
||||||
torch.cuda.is_available = torch.xpu.is_available
|
|
||||||
|
|
||||||
|
# Autocast
|
||||||
original_autocast = torch.autocast
|
original_autocast = torch.autocast
|
||||||
def ipex_autocast(*args, **kwargs):
|
def ipex_autocast(*args, **kwargs):
|
||||||
if len(args) > 0 and args[0] == "cuda":
|
if len(args) > 0 and args[0] == "cuda":
|
||||||
@@ -89,15 +32,7 @@ def ipex_autocast(*args, **kwargs):
|
|||||||
else:
|
else:
|
||||||
return original_autocast(*args, **kwargs)
|
return original_autocast(*args, **kwargs)
|
||||||
|
|
||||||
# Embedding BF16
|
# Latent Antialias CPU Offload:
|
||||||
original_torch_cat = torch.cat
|
|
||||||
def torch_cat(tensor, *args, **kwargs):
|
|
||||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
|
||||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return original_torch_cat(tensor, *args, **kwargs)
|
|
||||||
|
|
||||||
# Latent antialias:
|
|
||||||
original_interpolate = torch.nn.functional.interpolate
|
original_interpolate = torch.nn.functional.interpolate
|
||||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||||
if antialias or align_corners is not None:
|
if antialias or align_corners is not None:
|
||||||
@@ -109,19 +44,19 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
|||||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||||
|
|
||||||
original_linalg_solve = torch.linalg.solve
|
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||||
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
|
original_from_numpy = torch.from_numpy
|
||||||
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
|
def from_numpy(ndarray):
|
||||||
return_device = A.device
|
if ndarray.dtype == float:
|
||||||
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
|
return original_from_numpy(ndarray.astype('float32'))
|
||||||
else:
|
else:
|
||||||
return original_linalg_solve(A, B, *args, **kwargs)
|
return original_from_numpy(ndarray)
|
||||||
|
|
||||||
if torch.xpu.has_fp64_dtype():
|
if torch.xpu.has_fp64_dtype():
|
||||||
original_torch_bmm = torch.bmm
|
original_torch_bmm = torch.bmm
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
else:
|
else:
|
||||||
# 64 bit attention workarounds for Alchemist:
|
# 32 bit attention workarounds for Alchemist:
|
||||||
try:
|
try:
|
||||||
from .attention import torch_bmm_32_bit as original_torch_bmm
|
from .attention import torch_bmm_32_bit as original_torch_bmm
|
||||||
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
||||||
@@ -129,7 +64,8 @@ else:
|
|||||||
original_torch_bmm = torch.bmm
|
original_torch_bmm = torch.bmm
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
|
|
||||||
# dtype errors:
|
|
||||||
|
# Data Type Errors:
|
||||||
def torch_bmm(input, mat2, *, out=None):
|
def torch_bmm(input, mat2, *, out=None):
|
||||||
if input.dtype != mat2.dtype:
|
if input.dtype != mat2.dtype:
|
||||||
mat2 = mat2.to(input.dtype)
|
mat2 = mat2.to(input.dtype)
|
||||||
@@ -142,111 +78,171 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
|||||||
value = value.to(dtype=query.dtype)
|
value = value.to(dtype=query.dtype)
|
||||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||||
|
|
||||||
@property
|
# A1111 FP16
|
||||||
def is_cuda(self):
|
original_functional_group_norm = torch.nn.functional.group_norm
|
||||||
return self.device.type == 'xpu'
|
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||||
|
if weight is not None and input.dtype != weight.data.dtype:
|
||||||
|
input = input.to(dtype=weight.data.dtype)
|
||||||
|
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||||
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
||||||
|
|
||||||
def ipex_hijacks():
|
# A1111 BF16
|
||||||
CondFunc('torch.tensor',
|
original_functional_layer_norm = torch.nn.functional.layer_norm
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
if weight is not None and input.dtype != weight.data.dtype:
|
||||||
CondFunc('torch.Tensor.to',
|
input = input.to(dtype=weight.data.dtype)
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
CondFunc('torch.Tensor.cuda',
|
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
# Training
|
||||||
CondFunc('torch.UntypedStorage.__init__',
|
original_functional_linear = torch.nn.functional.linear
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
def functional_linear(input, weight, bias=None):
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
if input.dtype != weight.data.dtype:
|
||||||
CondFunc('torch.UntypedStorage.cuda',
|
input = input.to(dtype=weight.data.dtype)
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
CondFunc('torch.empty',
|
return original_functional_linear(input, weight, bias=bias)
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
original_functional_conv2d = torch.nn.functional.conv2d
|
||||||
CondFunc('torch.randn',
|
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
if input.dtype != weight.data.dtype:
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
input = input.to(dtype=weight.data.dtype)
|
||||||
CondFunc('torch.ones',
|
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||||
CondFunc('torch.zeros',
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
# A1111 Embedding BF16
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
original_torch_cat = torch.cat
|
||||||
CondFunc('torch.linspace',
|
def torch_cat(tensor, *args, **kwargs):
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||||
CondFunc('torch.load',
|
|
||||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
|
|
||||||
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
|
|
||||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
|
|
||||||
if hasattr(torch.xpu, "Generator"):
|
|
||||||
CondFunc('torch.Generator',
|
|
||||||
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
|
|
||||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
|
||||||
else:
|
else:
|
||||||
CondFunc('torch.Generator',
|
return original_torch_cat(tensor, *args, **kwargs)
|
||||||
lambda orig_func, device=None: orig_func(return_xpu(device)),
|
|
||||||
lambda orig_func, device=None: check_device(device))
|
|
||||||
|
|
||||||
# TiledVAE and ControlNet:
|
|
||||||
CondFunc('torch.batch_norm',
|
|
||||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
|
||||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
|
||||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
|
||||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
|
||||||
CondFunc('torch.instance_norm',
|
|
||||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
|
||||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
|
||||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
|
||||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
|
||||||
|
|
||||||
# Functions with dtype errors:
|
|
||||||
CondFunc('torch.nn.modules.GroupNorm.forward',
|
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
|
||||||
# Training:
|
|
||||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
|
||||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
|
||||||
# BF16:
|
|
||||||
CondFunc('torch.nn.functional.layer_norm',
|
|
||||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
|
||||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
|
||||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
|
||||||
weight is not None and input.dtype != weight.data.dtype)
|
|
||||||
# SwinIR BF16:
|
# SwinIR BF16:
|
||||||
CondFunc('torch.nn.functional.pad',
|
original_funtional_pad = torch.nn.functional.pad
|
||||||
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
|
def funtional_pad(input, pad, mode='constant', value=None):
|
||||||
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
|
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||||
|
return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||||
|
else:
|
||||||
|
return original_funtional_pad(input, pad, mode=mode, value=value)
|
||||||
|
|
||||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
|
||||||
if not torch.xpu.has_fp64_dtype():
|
|
||||||
CondFunc('torch.from_numpy',
|
|
||||||
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
|
|
||||||
lambda orig_func, ndarray: ndarray.dtype == float)
|
|
||||||
|
|
||||||
# Broken functions when torch.cuda.is_available is True:
|
original_torch_tensor = torch.tensor
|
||||||
# Pin Memory:
|
def torch_tensor(*args, device=None, **kwargs):
|
||||||
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
|
if check_device(device):
|
||||||
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
|
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
||||||
lambda orig_func, *args, **kwargs: True)
|
else:
|
||||||
|
return original_torch_tensor(*args, device=device, **kwargs)
|
||||||
|
|
||||||
# Functions that make compile mad with CondFunc:
|
original_Tensor_to = torch.Tensor.to
|
||||||
torch.nn.DataParallel = DummyDataParallel
|
def Tensor_to(self, device=None, *args, **kwargs):
|
||||||
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
|
if check_device(device):
|
||||||
|
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_Tensor_to(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_Tensor_cuda = torch.Tensor.cuda
|
||||||
|
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||||
|
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||||
|
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_torch_empty = torch.empty
|
||||||
|
def torch_empty(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_empty(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_randn = torch.randn
|
||||||
|
def torch_randn(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_randn(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_ones = torch.ones
|
||||||
|
def torch_ones(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_ones(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_zeros = torch.zeros
|
||||||
|
def torch_zeros(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_zeros(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_linspace = torch.linspace
|
||||||
|
def torch_linspace(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_linspace(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_Generator = torch.Generator
|
||||||
|
def torch_Generator(device=None):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_Generator(return_xpu(device))
|
||||||
|
else:
|
||||||
|
return original_torch_Generator(device)
|
||||||
|
|
||||||
|
original_torch_load = torch.load
|
||||||
|
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||||
|
if check_device(map_location):
|
||||||
|
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||||
|
|
||||||
|
# Hijack Functions:
|
||||||
|
def ipex_hijacks():
|
||||||
|
torch.tensor = torch_tensor
|
||||||
|
torch.Tensor.to = Tensor_to
|
||||||
|
torch.Tensor.cuda = Tensor_cuda
|
||||||
|
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||||
|
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||||
|
torch.empty = torch_empty
|
||||||
|
torch.randn = torch_randn
|
||||||
|
torch.ones = torch_ones
|
||||||
|
torch.zeros = torch_zeros
|
||||||
|
torch.linspace = torch_linspace
|
||||||
|
torch.Generator = torch_Generator
|
||||||
|
torch.load = torch_load
|
||||||
|
|
||||||
torch.autocast = ipex_autocast
|
|
||||||
torch.backends.cuda.sdp_kernel = return_null_context
|
torch.backends.cuda.sdp_kernel = return_null_context
|
||||||
|
torch.nn.DataParallel = DummyDataParallel
|
||||||
torch.UntypedStorage.is_cuda = is_cuda
|
torch.UntypedStorage.is_cuda = is_cuda
|
||||||
|
torch.autocast = ipex_autocast
|
||||||
|
|
||||||
|
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||||
|
torch.nn.functional.group_norm = functional_group_norm
|
||||||
|
torch.nn.functional.layer_norm = functional_layer_norm
|
||||||
|
torch.nn.functional.linear = functional_linear
|
||||||
|
torch.nn.functional.conv2d = functional_conv2d
|
||||||
torch.nn.functional.interpolate = interpolate
|
torch.nn.functional.interpolate = interpolate
|
||||||
torch.linalg.solve = linalg_solve
|
torch.nn.functional.pad = funtional_pad
|
||||||
|
|
||||||
torch.bmm = torch_bmm
|
torch.bmm = torch_bmm
|
||||||
torch.cat = torch_cat
|
torch.cat = torch_cat
|
||||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
if not torch.xpu.has_fp64_dtype():
|
||||||
|
torch.from_numpy = from_numpy
|
||||||
|
|||||||
Reference in New Issue
Block a user