diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py
index 51de31ad..7b2d26d4 100644
--- a/library/ipex/hijacks.py
+++ b/library/ipex/hijacks.py
@@ -5,6 +5,8 @@ import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
 import numpy as np
 
+device_supports_fp64 = torch.xpu.has_fp64_dtype()
+
 # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
 
 class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
@@ -49,6 +51,7 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
         return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
         align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
 
+
 # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
 original_from_numpy = torch.from_numpy
 @wraps(torch.from_numpy)
@@ -69,7 +72,8 @@ def as_tensor(data, dtype=None, device=None):
     else:
         return original_as_tensor(data, dtype=dtype, device=device)
 
-if torch.xpu.has_fp64_dtype() and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
+
+if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
     original_torch_bmm = torch.bmm
     original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
 else:
@@ -159,11 +163,16 @@ def functional_pad(input, pad, mode='constant', value=None):
 
 original_torch_tensor = torch.tensor
 @wraps(torch.tensor)
-def torch_tensor(*args, device=None, **kwargs):
+def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
     if check_device(device):
-        return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
-    else:
-        return original_torch_tensor(*args, device=device, **kwargs)
+        device = return_xpu(device)
+    if not device_supports_fp64:
+        if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
+            if dtype == torch.float64:
+                dtype = torch.float32
+            elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
+                dtype = torch.float32
+    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
 
 original_Tensor_to = torch.Tensor.to
 @wraps(torch.Tensor.to)
@@ -253,6 +262,7 @@ def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False,
     else:
         return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
 
+
# Hijack Functions:
 def ipex_hijacks():
     torch.tensor = torch_tensor
@@ -283,6 +293,6 @@ def ipex_hijacks():
     torch.bmm = torch_bmm
     torch.cat = torch_cat
 
-    if not torch.xpu.has_fp64_dtype():
+    if not device_supports_fp64:
         torch.from_numpy = from_numpy
         torch.as_tensor = as_tensor
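
For context, the dtype-downcast branch added to torch_tensor above can be exercised on its own. The following is a minimal sketch, not part of the patch: the helper name _downcast_dtype is invented for illustration, device_supports_fp64 is modeled as a plain parameter (defaulting to False to mimic an Alchemist-class GPU), and the NumPy case the real wrapper also catches via data.dtype == float is omitted. It runs without an XPU device because it only inspects the device argument.

import torch

def _downcast_dtype(data, dtype, device, device_supports_fp64=False):
    # Mirrors the new branch in torch_tensor: on an XPU device without
    # native fp64 support, an explicit or inferred float64 dtype
    # becomes float32.
    is_xpu = (isinstance(device, torch.device) and device.type == "xpu") \
        or (isinstance(device, str) and "xpu" in device)
    if not device_supports_fp64 and is_xpu:
        if dtype == torch.float64:
            return torch.float32
        if dtype is None and hasattr(data, "dtype") and data.dtype == torch.float64:
            return torch.float32
    return dtype

# Explicit float64 on a (hypothetical) non-fp64 XPU is downcast:
assert _downcast_dtype([1.0], torch.float64, "xpu") == torch.float32
# A dtype inferred from float64 input data is downcast as well:
assert _downcast_dtype(torch.ones(1, dtype=torch.float64), None, "xpu") == torch.float32
# Other devices are left untouched:
assert _downcast_dtype([1.0], torch.float64, "cpu") == torch.float64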