mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
IPEX torch.tensor FP64 workaround
This commit is contained in:
@@ -5,6 +5,8 @@ import torch
|
|||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||||
|
|
||||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||||
@@ -49,6 +51,7 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
|||||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||||
|
|
||||||
|
|
||||||
# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
|
# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
|
||||||
original_from_numpy = torch.from_numpy
|
original_from_numpy = torch.from_numpy
|
||||||
@wraps(torch.from_numpy)
|
@wraps(torch.from_numpy)
|
||||||
@@ -69,7 +72,8 @@ def as_tensor(data, dtype=None, device=None):
|
|||||||
else:
|
else:
|
||||||
return original_as_tensor(data, dtype=dtype, device=device)
|
return original_as_tensor(data, dtype=dtype, device=device)
|
||||||
|
|
||||||
if torch.xpu.has_fp64_dtype() and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
|
||||||
|
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
||||||
original_torch_bmm = torch.bmm
|
original_torch_bmm = torch.bmm
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
else:
|
else:
|
||||||
@@ -159,11 +163,16 @@ def functional_pad(input, pad, mode='constant', value=None):
|
|||||||
|
|
||||||
original_torch_tensor = torch.tensor
|
original_torch_tensor = torch.tensor
|
||||||
@wraps(torch.tensor)
|
@wraps(torch.tensor)
|
||||||
def torch_tensor(*args, device=None, **kwargs):
|
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||||
if check_device(device):
|
if check_device(device):
|
||||||
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
device = return_xpu(device)
|
||||||
else:
|
if not device_supports_fp64:
|
||||||
return original_torch_tensor(*args, device=device, **kwargs)
|
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
||||||
|
if dtype == torch.float64:
|
||||||
|
dtype = torch.float32
|
||||||
|
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
||||||
|
dtype = torch.float32
|
||||||
|
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
||||||
|
|
||||||
original_Tensor_to = torch.Tensor.to
|
original_Tensor_to = torch.Tensor.to
|
||||||
@wraps(torch.Tensor.to)
|
@wraps(torch.Tensor.to)
|
||||||
@@ -253,6 +262,7 @@ def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False,
|
|||||||
else:
|
else:
|
||||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Hijack Functions:
|
# Hijack Functions:
|
||||||
def ipex_hijacks():
|
def ipex_hijacks():
|
||||||
torch.tensor = torch_tensor
|
torch.tensor = torch_tensor
|
||||||
@@ -283,6 +293,6 @@ def ipex_hijacks():
|
|||||||
|
|
||||||
torch.bmm = torch_bmm
|
torch.bmm = torch_bmm
|
||||||
torch.cat = torch_cat
|
torch.cat = torch_cat
|
||||||
if not torch.xpu.has_fp64_dtype():
|
if not device_supports_fp64:
|
||||||
torch.from_numpy = from_numpy
|
torch.from_numpy = from_numpy
|
||||||
torch.as_tensor = as_tensor
|
torch.as_tensor = as_tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user