Update IPEX libs

This commit is contained in:
Disty0
2025-06-13 16:25:16 +03:00
parent d8717a3d1c
commit bcd3a5a60a
6 changed files with 316 additions and 335 deletions

View File

@@ -153,7 +153,7 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
)
else:
ort_sess = ort.InferenceSession(

View File

@@ -1,14 +1,15 @@
import os
import sys
import contextlib
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
legacy = True
has_ipex = True
except Exception:
legacy = False
has_ipex = False
from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3])
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
@@ -16,7 +17,10 @@ def ipex_init(): # pylint: disable=too-many-statements
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
try: # force xpu device on torch compile and triton
try:
# force xpu device on torch compile and triton
# import inductor utils to get around lazy import
from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
from triton import backends as triton_backends # pylint: disable=import-error
@@ -35,7 +39,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
@@ -45,7 +48,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
@@ -58,7 +60,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
@@ -70,47 +71,40 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if legacy:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if float(ipex.__version__[:3]) < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new
if torch_version < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
if not legacy or float(ipex.__version__[:3]) >= 2.3:
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
else:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
@@ -120,12 +114,24 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7:
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List
# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
if legacy:
if has_ipex:
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory = torch.xpu.memory
@@ -153,40 +159,19 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP:
if legacy:
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.xpu.amp
if float(ipex.__version__[:3]) < 2.3:
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C
if legacy and float(ipex.__version__[:3]) < 2.3:
if torch_version < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
ipex._C._DeviceProperties.minor = 1
ipex._C._DeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
else:
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
torch._C._XpuDeviceProperties.major = 12
torch._C._XpuDeviceProperties.minor = 1
torch._C._XpuDeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
# Fix functions with ipex:
# torch.xpu.mem_get_info always returns the total memory as free memory
@@ -195,21 +180,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.get_device_properties.L2_cache_size = 16*1024*1024 # A770 and A750
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
device_supports_fp64 = ipex_hijacks()
try:
from .diffusers import ipex_diffusers
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
ipex_diffusers(device_supports_fp64=device_supports_fp64)
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True

View File

@@ -61,13 +61,13 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
is_unsqueezed = False
if len(query.shape) == 3:
if query.dim() == 3:
query = query.unsqueeze(0)
is_unsqueezed = True
if len(key.shape) == 3:
key = key.unsqueeze(0)
if len(value.shape) == 3:
value = value.unsqueeze(0)
if key.dim() == 3:
key = key.unsqueeze(0)
if value.dim() == 3:
value = value.unsqueeze(0)
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
# Slice SDPA
@@ -115,5 +115,5 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop
else:
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
if is_unsqueezed:
hidden_states.squeeze(0)
hidden_states = hidden_states.squeeze(0)
return hidden_states

View File

@@ -1,11 +1,13 @@
from functools import wraps
import torch
import diffusers # pylint: disable=import-error
from diffusers.utils import torch_utils # pylint: disable=import-error, unused-import # noqa: F401
# pylint: disable=protected-access, missing-function-docstring, line-too-long
# Diffusers FreeU
# Diffusers is imported before ipex hijacks so fourier_filter needs hijacking too
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
@wraps(diffusers.utils.torch_utils.fourier_filter)
def fourier_filter(x_in, threshold, scale):
@@ -41,7 +43,84 @@ class FluxPosEmbed(torch.nn.Module):
return freqs_cos, freqs_sin
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
def hidream_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
return_device = pos.device
pos = pos.to("cpu")
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.to(return_device, dtype=torch.float32)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
if output_type == "np":
return diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.outer(pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
def apply_rotary_emb(x, freqs_cis, use_real: bool = True, use_real_unbind_dim: int = -1):
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
# force cpu with Alchemist
x_rotated = torch.view_as_complex(x.to("cpu").float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.to("cpu").unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x).to(x.device)
def ipex_diffusers(device_supports_fp64=False):
diffusers.utils.torch_utils.fourier_filter = fourier_filter
if not device_supports_fp64:
# get around lazy imports
from diffusers.models import embeddings as diffusers_embeddings # pylint: disable=import-error, unused-import # noqa: F401
from diffusers.models import transformers as diffusers_transformers # pylint: disable=import-error, unused-import # noqa: F401
from diffusers.models import controlnets as diffusers_controlnets # pylint: disable=import-error, unused-import # noqa: F401
diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid = get_1d_sincos_pos_embed_from_grid
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
diffusers.models.embeddings.apply_rotary_emb = apply_rotary_emb
diffusers.models.transformers.transformer_flux.FluxPosEmbed = FluxPosEmbed
diffusers.models.transformers.transformer_lumina2.apply_rotary_emb = apply_rotary_emb
diffusers.models.controlnets.controlnet_flux.FluxPosEmbed = FluxPosEmbed
diffusers.models.transformers.transformer_hidream_image.rope = hidream_rope

View File

@@ -1,183 +0,0 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
# sync grad to master weight
if hasattr(optimizer, "sync_grad"):
optimizer.sync_grad()
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
core._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get("cpu"),
per_device_inv_scale.get("cpu"),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
if device_supports_fp64:
inv_scale = self._scale.double().reciprocal().float()
else:
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device="cpu", non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
to_device = _scale.device
_scale = _scale.to("cpu")
_growth_tracker = _growth_tracker.to("cpu")
core._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
_scale = _scale.to(to_device)
_growth_tracker = _growth_tracker.to(to_device)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
torch.xpu.amp.GradScaler.unscale_ = unscale_
torch.xpu.amp.GradScaler.update = update
return torch.xpu.amp.GradScaler

View File

@@ -4,17 +4,23 @@ from contextlib import nullcontext
import torch
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
del x
torch.xpu.empty_cache()
can_allocate_plus_4gb = True
except Exception:
can_allocate_plus_4gb = False
torch_version = float(torch.__version__[:3])
current_xpu_device = f"xpu:{torch.xpu.current_device()}"
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(current_xpu_device).has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0':
if (torch.xpu.get_device_properties(current_xpu_device).total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device=current_xpu_device)
del x
torch.xpu.empty_cache()
use_dynamic_attention = False
except Exception:
use_dynamic_attention = True
else:
use_dynamic_attention = True
else:
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
use_dynamic_attention = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '1')
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -22,32 +28,67 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
return module.to(f"xpu:{torch.xpu.current_device()}")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return nullcontext()
@property
def is_cuda(self):
return self.device.type == 'xpu' or self.device.type == 'cuda'
return self.device.type == "xpu" or self.device.type == "cuda"
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def check_device_type(device, device_type: str) -> bool:
if device is None or type(device) not in {str, int, torch.device}:
return False
else:
return bool(torch.device(device).type == device_type)
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
def check_cuda(device) -> bool:
return bool(isinstance(device, int) or check_device_type(device, "cuda"))
def return_xpu(device): # keep the device instance type, aka return string if the input is string
return f"xpu:{torch.xpu.current_device()}" if device is None else f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
# Autocast
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
if device_type == "cuda":
def autocast_init(self, device_type=None, dtype=None, enabled=True, cache_enabled=None):
if device_type is None or check_cuda(device_type):
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else:
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
original_grad_scaler_init = torch.amp.grad_scaler.GradScaler.__init__
@wraps(torch.amp.grad_scaler.GradScaler.__init__)
def GradScaler_init(self, device: str = None, init_scale: float = 2.0**16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True):
if device is None or check_cuda(device):
return original_grad_scaler_init(self, device=return_xpu(device), init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled)
else:
return original_grad_scaler_init(self, device=device, init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled)
original_is_autocast_enabled = torch.is_autocast_enabled
@wraps(torch.is_autocast_enabled)
def torch_is_autocast_enabled(device_type=None):
if device_type is None or check_cuda(device_type):
return original_is_autocast_enabled(return_xpu(device_type))
else:
return original_is_autocast_enabled(device_type)
original_get_autocast_dtype = torch.get_autocast_dtype
@wraps(torch.get_autocast_dtype)
def torch_get_autocast_dtype(device_type=None):
if device_type is None or check_cuda(device_type) or check_device_type(device_type, "xpu"):
return torch.bfloat16
else:
return original_get_autocast_dtype(device_type)
# Latent Antialias CPU Offload:
# IPEX 2.5 and above has partial support but doesn't really work most of the time.
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
@@ -66,23 +107,22 @@ original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
return original_from_numpy(ndarray.astype("float32"))
else:
return original_from_numpy(ndarray)
original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
if check_cuda(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
if isinstance(data, np.ndarray) and data.dtype == float and not check_device_type(device, "cpu"):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)
if can_allocate_plus_4gb:
if not use_dynamic_attention:
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
@@ -106,7 +146,7 @@ original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
mat2 = mat2.to(dtype=input.dtype)
return original_torch_bmm(input, mat2, out=out)
# Diffusers FreeU
@@ -195,38 +235,36 @@ original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
global device_supports_fp64
if check_device(device):
if check_cuda(device):
device = return_xpu(device)
if not device_supports_fp64:
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if check_device_type(device, "xpu"):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
original_Tensor_to = torch.Tensor.to
torch.Tensor.original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
if check_cuda(device):
return self.original_Tensor_to(return_xpu(device), *args, **kwargs)
else:
return original_Tensor_to(self, device, *args, **kwargs)
return self.original_Tensor_to(device, *args, **kwargs)
original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
if device is None or check_cuda(device):
return self.to(return_xpu(device), *args, **kwargs)
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
if device is None or check_cuda(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)
@@ -234,23 +272,32 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs):
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
if torch_version >= 2.4:
original_UntypedStorage_to = torch.UntypedStorage.to
@wraps(torch.UntypedStorage.to)
def UntypedStorage_to(self, *args, device=None, **kwargs):
if check_cuda(device):
return original_UntypedStorage_to(self, *args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_to(self, *args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
if device is None or check_cuda(device):
return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs)
else:
return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_empty(*args, device=device, **kwargs)
@@ -260,7 +307,7 @@ original_torch_randn = torch.randn
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype is bytes:
dtype = None
if check_device(device):
if check_cuda(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, **kwargs)
@@ -268,7 +315,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_ones(*args, device=device, **kwargs)
@@ -276,7 +323,7 @@ def torch_ones(*args, device=None, **kwargs):
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_zeros(*args, device=device, **kwargs)
@@ -284,7 +331,7 @@ def torch_zeros(*args, device=None, **kwargs):
original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_full(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_full(*args, device=device, **kwargs)
@@ -292,63 +339,91 @@ def torch_full(*args, device=None, **kwargs):
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_eye = torch.eye
@wraps(torch.eye)
def torch_eye(*args, device=None, **kwargs):
if check_cuda(device):
return original_torch_eye(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_eye(*args, device=device, **kwargs)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
if map_location is None or check_cuda(map_location):
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, *args, map_location=map_location, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
if check_device(device):
if check_cuda(device):
return torch.xpu.synchronize(return_xpu(device))
else:
return torch.xpu.synchronize(device)
@wraps(torch.cuda.device)
def torch_cuda_device(device):
if check_cuda(device):
return torch.xpu.device(return_xpu(device))
else:
return torch.xpu.device(device)
@wraps(torch.cuda.set_device)
def torch_cuda_set_device(device):
if check_cuda(device):
torch.xpu.set_device(return_xpu(device))
else:
torch.xpu.set_device(device)
# torch.Generator has to be a class for isinstance checks
original_torch_Generator = torch.Generator
class torch_Generator(original_torch_Generator):
def __new__(self, device=None):
# can't hijack __init__ because of C override so use return super().__new__
if check_cuda(device):
return super().__new__(self, return_xpu(device))
else:
return super().__new__(self, device)
# Hijack Functions:
def ipex_hijacks(legacy=True):
global device_supports_fp64, can_allocate_plus_4gb
if legacy and float(torch.__version__[:3]) < 2.5:
torch.nn.functional.interpolate = interpolate
def ipex_hijacks():
global device_supports_fp64
if torch_version >= 2.4:
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.UntypedStorage.to = UntypedStorage_to
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.full = torch_full
torch.linspace = torch_linspace
torch.eye = torch_eye
torch.load = torch_load
torch.Generator = torch_Generator
torch.cuda.synchronize = torch_cuda_synchronize
torch.cuda.device = torch_cuda_device
torch.cuda.set_device = torch_cuda_set_device
torch.Generator = torch_Generator
torch._C.Generator = torch_Generator
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.amp.autocast_mode.autocast.__init__ = autocast_init
torch.nn.functional.interpolate = interpolate
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
@@ -364,4 +439,28 @@ def ipex_hijacks(legacy=True):
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor
return device_supports_fp64, can_allocate_plus_4gb
# AMP:
torch.amp.grad_scaler.GradScaler.__init__ = GradScaler_init
torch.is_autocast_enabled = torch_is_autocast_enabled
torch.get_autocast_gpu_dtype = torch_get_autocast_dtype
torch.get_autocast_dtype = torch_get_autocast_dtype
if hasattr(torch.xpu, "amp"):
if not hasattr(torch.xpu.amp, "custom_fwd"):
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
if not hasattr(torch.xpu.amp, "GradScaler"):
torch.xpu.amp.GradScaler = torch.amp.grad_scaler.GradScaler
torch.cuda.amp = torch.xpu.amp
else:
if not hasattr(torch.amp, "custom_fwd"):
torch.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
return device_supports_fp64