Merge pull request #985 from Disty0/dev

Update IPEX hijacks
This commit is contained in:
Kohya S
2023-12-07 21:39:24 +09:00
committed by GitHub
5 changed files with 63 additions and 27 deletions

View File

@@ -4,7 +4,6 @@ import contextlib
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -30,6 +29,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
@@ -164,7 +164,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
ipex_hijacks() ipex_hijacks()
if not torch.xpu.has_fp64_dtype():
try:
from .attention import attention_init
attention_init() attention_init()
except Exception: # pylint: disable=broad-exception-caught
pass
try: try:
from .diffusers import ipex_diffusers from .diffusers import ipex_diffusers
ipex_diffusers() ipex_diffusers()

View File

@@ -74,6 +74,11 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
shape_one, batch_size_attention, query_tokens, shape_four = query.shape shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False no_shape_one = False
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
block_multiply = query.element_size() block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size block_size = batch_size_attention * slice_block_size

View File

@@ -1,6 +1,6 @@
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.21.1 # pylint: disable=import-error import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention from diffusers.models.attention_processor import Attention
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long

View File

@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,6 +97,9 @@ def unscale_(self, optimizer):
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None assert self._scale is not None
if device_supports_fp64:
inv_scale = self._scale.double().reciprocal().float()
else:
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full( found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device (1,), 0.0, dtype=torch.float32, device=self._scale.device

View File

@@ -89,6 +89,7 @@ def ipex_autocast(*args, **kwargs):
else: else:
return original_autocast(*args, **kwargs) return original_autocast(*args, **kwargs)
# Embedding BF16
original_torch_cat = torch.cat original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs): def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
@@ -96,6 +97,7 @@ def torch_cat(tensor, *args, **kwargs):
else: else:
return original_torch_cat(tensor, *args, **kwargs) return original_torch_cat(tensor, *args, **kwargs)
# Latent antialias:
original_interpolate = torch.nn.functional.interpolate original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None: if antialias or align_corners is not None:
@@ -115,19 +117,29 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
else: else:
return original_linalg_solve(A, B, *args, **kwargs) return original_linalg_solve(A, B, *args, **kwargs)
@property
def is_cuda(self):
return self.device.type == 'xpu'
def ipex_hijacks(): def ipex_hijacks():
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Tensor.to', CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda', CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.UntypedStorage.__init__',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.UntypedStorage.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty', CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device)) lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn', CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device)) lambda orig_func, *args, device=None, **kwargs: check_device(device))
@@ -137,17 +149,19 @@ def ipex_hijacks():
CondFunc('torch.zeros', CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device)) lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace', CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device)) lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
CondFunc('torch.Generator', CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device), lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
# TiledVAE and ControlNet:
CondFunc('torch.batch_norm', CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device), weight if weight is not None else torch.ones(input.size()[1], device=input.device),
@@ -163,25 +177,32 @@ def ipex_hijacks():
CondFunc('torch.nn.modules.GroupNorm.forward', CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# Training:
CondFunc('torch.nn.modules.linear.Linear.forward', CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward', CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# BF16:
CondFunc('torch.nn.functional.layer_norm', CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype) weight is not None and input.dtype != weight.data.dtype)
# SwinIR BF16:
CondFunc('torch.nn.functional.pad',
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64): # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
if not torch.xpu.has_fp64_dtype(): if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy', CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float) lambda orig_func, ndarray: ndarray.dtype == float)
# Broken functions when torch.cuda.is_available is True: # Broken functions when torch.cuda.is_available is True:
# Pin Memory:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True) lambda orig_func, *args, **kwargs: True)
@@ -192,5 +213,6 @@ def ipex_hijacks():
torch.autocast = ipex_autocast torch.autocast = ipex_autocast
torch.cat = torch_cat torch.cat = torch_cat
torch.linalg.solve = linalg_solve torch.linalg.solve = linalg_solve
torch.UntypedStorage.is_cuda = is_cuda
torch.nn.functional.interpolate = interpolate torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context torch.backends.cuda.sdp_kernel = return_null_context