diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index f8f6ddd9..3bf71fed 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -153,7 +153,7 @@ def main(args): ort_sess = ort.InferenceSession( onnx_path, providers=(["OpenVINOExecutionProvider"]), - provider_options=[{'device_type' : "GPU_FP32"}], + provider_options=[{'device_type' : "GPU", "precision": "FP32"}], ) else: ort_sess = ort.InferenceSession( diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index a36664bb..a44531f3 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -1,14 +1,15 @@ import os import sys -import contextlib import torch try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import - legacy = True + has_ipex = True except Exception: - legacy = False + has_ipex = False from .hijacks import ipex_hijacks +torch_version = float(torch.__version__[:3]) + # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements @@ -16,7 +17,10 @@ def ipex_init(): # pylint: disable=too-many-statements if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: return True, "Skipping IPEX hijack" else: - try: # force xpu device on torch compile and triton + try: + # force xpu device on torch compile and triton + # import inductor utils to get around lazy import + from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401 torch._inductor.utils.GPU_TYPES = ["xpu"] torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu" from triton import backends as triton_backends # pylint: disable=import-error @@ -35,7 +39,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_available = torch.xpu.is_available torch.cuda.is_initialized = torch.xpu.is_initialized torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device torch.cuda.stream = torch.xpu.stream torch.cuda.Event = torch.xpu.Event torch.cuda.Stream = torch.xpu.Stream @@ -45,7 +48,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.Optional = torch.xpu.Optional torch.cuda.__cached__ = torch.xpu.__cached__ torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.streams = torch.xpu.streams torch.cuda.Any = torch.xpu.Any torch.cuda.__doc__ = torch.xpu.__doc__ @@ -58,7 +60,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.List = torch.xpu.List torch.cuda._lazy_init = torch.xpu._lazy_init torch.cuda.StreamContext = torch.xpu.StreamContext torch.cuda._lazy_call = torch.xpu._lazy_call @@ -70,47 +71,40 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.__file__ = torch.xpu.__file__ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - if legacy: - torch.cuda.os = torch.xpu.os - torch.cuda.Device = torch.xpu.Device - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.classproperty = torch.xpu.classproperty - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - if float(ipex.__version__[:3]) < 2.3: - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda._lazy_new = torch.xpu._lazy_new + if torch_version < 2.3: + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - - if not legacy or float(ipex.__version__[:3]) >= 2.3: + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + else: torch.cuda._initialization_lock = torch.xpu._initialization_lock torch.cuda._initialized = torch.xpu._initialized torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork @@ -120,12 +114,24 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.threading = torch.xpu.threading torch.cuda.traceback = torch.xpu.traceback + if torch_version < 2.5: + torch.cuda.os = torch.xpu.os + torch.cuda.Device = torch.xpu.Device + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.classproperty = torch.xpu.classproperty + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + + if torch_version < 2.7: + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.List = torch.xpu.List + + # Memory: if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None torch.cuda.empty_cache = torch.xpu.empty_cache - if legacy: + if has_ipex: torch.cuda.memory_summary = torch.xpu.memory_summary torch.cuda.memory_snapshot = torch.xpu.memory_snapshot torch.cuda.memory = torch.xpu.memory @@ -153,40 +159,19 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.seed_all = torch.xpu.seed_all torch.cuda.initial_seed = torch.xpu.initial_seed - # AMP: - if legacy: - torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd - torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd - torch.cuda.amp = torch.xpu.amp - if float(ipex.__version__[:3]) < 2.3: - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False - - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught - try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - # C - if legacy and float(ipex.__version__[:3]) < 2.3: + if torch_version < 2.3: torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.major = 12 ipex._C._DeviceProperties.minor = 1 + ipex._C._DeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750 else: torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count torch._C._XpuDeviceProperties.major = 12 torch._C._XpuDeviceProperties.minor = 1 + torch._C._XpuDeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750 # Fix functions with ipex: # torch.xpu.mem_get_info always returns the total memory as free memory @@ -195,21 +180,22 @@ def ipex_init(): # pylint: disable=too-many-statements torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True) torch.cuda.is_fp16_supported = lambda *args, **kwargs: True torch.backends.cuda.is_built = lambda *args, **kwargs: True torch.version.cuda = "12.1" - torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"] + torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"]) torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1) torch.cuda.get_device_properties.major = 12 torch.cuda.get_device_properties.minor = 1 + torch.cuda.get_device_properties.L2_cache_size = 16*1024*1024 # A770 and A750 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy) + device_supports_fp64 = ipex_hijacks() try: from .diffusers import ipex_diffusers - ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb) + ipex_diffusers(device_supports_fp64=device_supports_fp64) except Exception: # pylint: disable=broad-exception-caught pass torch.cuda.is_xpu_hijacked = True diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 400b59b6..177f5bc5 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -61,13 +61,13 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop if query.device.type != "xpu": return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) is_unsqueezed = False - if len(query.shape) == 3: + if query.dim() == 3: query = query.unsqueeze(0) is_unsqueezed = True - if len(key.shape) == 3: - key = key.unsqueeze(0) - if len(value.shape) == 3: - value = value.unsqueeze(0) + if key.dim() == 3: + key = key.unsqueeze(0) + if value.dim() == 3: + value = value.unsqueeze(0) do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate) # Slice SDPA @@ -115,5 +115,5 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop else: hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) if is_unsqueezed: - hidden_states.squeeze(0) + hidden_states = hidden_states.squeeze(0) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 75715d16..d3487fef 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,11 +1,13 @@ from functools import wraps import torch import diffusers # pylint: disable=import-error +from diffusers.utils import torch_utils # pylint: disable=import-error, unused-import # noqa: F401 # pylint: disable=protected-access, missing-function-docstring, line-too-long # Diffusers FreeU +# Diffusers is imported before ipex hijacks so fourier_filter needs hijacking too original_fourier_filter = diffusers.utils.torch_utils.fourier_filter @wraps(diffusers.utils.torch_utils.fourier_filter) def fourier_filter(x_in, threshold, scale): @@ -41,7 +43,84 @@ class FluxPosEmbed(torch.nn.Module): return freqs_cos, freqs_sin -def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False): +def hidream_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." + return_device = pos.device + pos = pos.to("cpu") + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.to(return_device, dtype=torch.float32) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): + if output_type == "np": + return diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.outer(pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +def apply_rotary_emb(x, freqs_cis, use_real: bool = True, use_real_unbind_dim: int = -1): + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + else: + # used for lumina + # force cpu with Alchemist + x_rotated = torch.view_as_complex(x.to("cpu").float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.to("cpu").unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x).to(x.device) + + +def ipex_diffusers(device_supports_fp64=False): diffusers.utils.torch_utils.fourier_filter = fourier_filter if not device_supports_fp64: + # get around lazy imports + from diffusers.models import embeddings as diffusers_embeddings # pylint: disable=import-error, unused-import # noqa: F401 + from diffusers.models import transformers as diffusers_transformers # pylint: disable=import-error, unused-import # noqa: F401 + from diffusers.models import controlnets as diffusers_controlnets # pylint: disable=import-error, unused-import # noqa: F401 + diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid = get_1d_sincos_pos_embed_from_grid diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed + diffusers.models.embeddings.apply_rotary_emb = apply_rotary_emb + diffusers.models.transformers.transformer_flux.FluxPosEmbed = FluxPosEmbed + diffusers.models.transformers.transformer_lumina2.apply_rotary_emb = apply_rotary_emb + diffusers.models.controlnets.controlnet_flux.FluxPosEmbed = FluxPosEmbed + diffusers.models.transformers.transformer_hidream_image.rope = hidream_rope diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py deleted file mode 100644 index 0a861009..00000000 --- a/library/ipex/gradscaler.py +++ /dev/null @@ -1,183 +0,0 @@ -from collections import defaultdict -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 -OptState = ipex.cpu.autocast._grad_scaler.OptState -_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator -_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state - -def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument - per_device_inv_scale = _MultiDeviceReplicator(inv_scale) - per_device_found_inf = _MultiDeviceReplicator(found_inf) - - # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. - # There could be hundreds of grads, so we'd like to iterate through them just once. - # However, we don't know their devices or dtypes in advance. - - # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict - # Google says mypy struggles with defaultdicts type annotations. - per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] - # sync grad to master weight - if hasattr(optimizer, "sync_grad"): - optimizer.sync_grad() - with torch.no_grad(): - for group in optimizer.param_groups: - for param in group["params"]: - if param.grad is None: - continue - if (not allow_fp16) and param.grad.dtype == torch.float16: - raise ValueError("Attempting to unscale FP16 gradients.") - if param.grad.is_sparse: - # is_coalesced() == False means the sparse grad has values with duplicate indices. - # coalesce() deduplicates indices and adds all values that have the same index. - # For scaled fp16 values, there's a good chance coalescing will cause overflow, - # so we should check the coalesced _values(). - if param.grad.dtype is torch.float16: - param.grad = param.grad.coalesce() - to_unscale = param.grad._values() - else: - to_unscale = param.grad - - # -: is there a way to split by device and dtype without appending in the inner loop? - to_unscale = to_unscale.to("cpu") - per_device_and_dtype_grads[to_unscale.device][ - to_unscale.dtype - ].append(to_unscale) - - for _, per_dtype_grads in per_device_and_dtype_grads.items(): - for grads in per_dtype_grads.values(): - core._amp_foreach_non_finite_check_and_unscale_( - grads, - per_device_found_inf.get("cpu"), - per_device_inv_scale.get("cpu"), - ) - - return per_device_found_inf._per_device_tensors - -def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise - raise RuntimeError( - "unscale_() has already been called on this optimizer since the last update()." - ) - elif optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. - assert self._scale is not None - if device_supports_fp64: - inv_scale = self._scale.double().reciprocal().float() - else: - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) - found_inf = torch.full( - (1,), 0.0, dtype=torch.float32, device=self._scale.device - ) - - optimizer_state["found_inf_per_device"] = self._unscale_grads_( - optimizer, inv_scale, found_inf, False - ) - optimizer_state["stage"] = OptState.UNSCALED - -def update(self, new_scale=None): - """ - Updates the scale factor. - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - Args: - new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [ - found_inf.to(device="cpu", non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - to_device = _scale.device - _scale = _scale.to("cpu") - _growth_tracker = _growth_tracker.to("cpu") - - core._amp_update_scale_( - _scale, - _growth_tracker, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval, - ) - - _scale = _scale.to(to_device) - _growth_tracker = _growth_tracker.to(to_device) - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - -def gradscaler_init(): - torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ - torch.xpu.amp.GradScaler.unscale_ = unscale_ - torch.xpu.amp.GradScaler.update = update - return torch.xpu.amp.GradScaler diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 91569746..29df78e9 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -4,17 +4,23 @@ from contextlib import nullcontext import torch import numpy as np -device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 -if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1: - try: - x = torch.ones((33000,33000), dtype=torch.float32, device="xpu") - del x - torch.xpu.empty_cache() - can_allocate_plus_4gb = True - except Exception: - can_allocate_plus_4gb = False +torch_version = float(torch.__version__[:3]) +current_xpu_device = f"xpu:{torch.xpu.current_device()}" +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(current_xpu_device).has_fp64 + +if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0': + if (torch.xpu.get_device_properties(current_xpu_device).total_memory / 1024 / 1024 / 1024) > 4.1: + try: + x = torch.ones((33000,33000), dtype=torch.float32, device=current_xpu_device) + del x + torch.xpu.empty_cache() + use_dynamic_attention = False + except Exception: + use_dynamic_attention = True + else: + use_dynamic_attention = True else: - can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1') + use_dynamic_attention = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '1') # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -22,32 +28,67 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: print("IPEX backend doesn't support DataParallel on multiple XPU devices") - return module.to("xpu") + return module.to(f"xpu:{torch.xpu.current_device()}") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument return nullcontext() @property def is_cuda(self): - return self.device.type == 'xpu' or self.device.type == 'cuda' + return self.device.type == "xpu" or self.device.type == "cuda" -def check_device(device): - return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) +def check_device_type(device, device_type: str) -> bool: + if device is None or type(device) not in {str, int, torch.device}: + return False + else: + return bool(torch.device(device).type == device_type) -def return_xpu(device): - return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu" +def check_cuda(device) -> bool: + return bool(isinstance(device, int) or check_device_type(device, "cuda")) + +def return_xpu(device): # keep the device instance type, aka return string if the input is string + return f"xpu:{torch.xpu.current_device()}" if device is None else f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu" # Autocast original_autocast_init = torch.amp.autocast_mode.autocast.__init__ @wraps(torch.amp.autocast_mode.autocast.__init__) -def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): - if device_type == "cuda": +def autocast_init(self, device_type=None, dtype=None, enabled=True, cache_enabled=None): + if device_type is None or check_cuda(device_type): return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) + +original_grad_scaler_init = torch.amp.grad_scaler.GradScaler.__init__ +@wraps(torch.amp.grad_scaler.GradScaler.__init__) +def GradScaler_init(self, device: str = None, init_scale: float = 2.0**16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True): + if device is None or check_cuda(device): + return original_grad_scaler_init(self, device=return_xpu(device), init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled) + else: + return original_grad_scaler_init(self, device=device, init_scale=init_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, enabled=enabled) + + +original_is_autocast_enabled = torch.is_autocast_enabled +@wraps(torch.is_autocast_enabled) +def torch_is_autocast_enabled(device_type=None): + if device_type is None or check_cuda(device_type): + return original_is_autocast_enabled(return_xpu(device_type)) + else: + return original_is_autocast_enabled(device_type) + + +original_get_autocast_dtype = torch.get_autocast_dtype +@wraps(torch.get_autocast_dtype) +def torch_get_autocast_dtype(device_type=None): + if device_type is None or check_cuda(device_type) or check_device_type(device_type, "xpu"): + return torch.bfloat16 + else: + return original_get_autocast_dtype(device_type) + + # Latent Antialias CPU Offload: +# IPEX 2.5 and above has partial support but doesn't really work most of the time. original_interpolate = torch.nn.functional.interpolate @wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments @@ -66,23 +107,22 @@ original_from_numpy = torch.from_numpy @wraps(torch.from_numpy) def from_numpy(ndarray): if ndarray.dtype == float: - return original_from_numpy(ndarray.astype('float32')) + return original_from_numpy(ndarray.astype("float32")) else: return original_from_numpy(ndarray) original_as_tensor = torch.as_tensor @wraps(torch.as_tensor) def as_tensor(data, dtype=None, device=None): - if check_device(device): + if check_cuda(device): device = return_xpu(device) - if isinstance(data, np.ndarray) and data.dtype == float and not ( - (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + if isinstance(data, np.ndarray) and data.dtype == float and not check_device_type(device, "cpu"): return original_as_tensor(data, dtype=torch.float32, device=device) else: return original_as_tensor(data, dtype=dtype, device=device) -if can_allocate_plus_4gb: +if not use_dynamic_attention: original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: # 32 bit attention workarounds for Alchemist: @@ -106,7 +146,7 @@ original_torch_bmm = torch.bmm @wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) + mat2 = mat2.to(dtype=input.dtype) return original_torch_bmm(input, mat2, out=out) # Diffusers FreeU @@ -195,38 +235,36 @@ original_torch_tensor = torch.tensor @wraps(torch.tensor) def torch_tensor(data, *args, dtype=None, device=None, **kwargs): global device_supports_fp64 - if check_device(device): + if check_cuda(device): device = return_xpu(device) if not device_supports_fp64: - if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if check_device_type(device, "xpu"): if dtype == torch.float64: dtype = torch.float32 elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): dtype = torch.float32 return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) -original_Tensor_to = torch.Tensor.to +torch.Tensor.original_Tensor_to = torch.Tensor.to @wraps(torch.Tensor.to) def Tensor_to(self, device=None, *args, **kwargs): - if check_device(device): - return original_Tensor_to(self, return_xpu(device), *args, **kwargs) + if check_cuda(device): + return self.original_Tensor_to(return_xpu(device), *args, **kwargs) else: - return original_Tensor_to(self, device, *args, **kwargs) + return self.original_Tensor_to(device, *args, **kwargs) original_Tensor_cuda = torch.Tensor.cuda @wraps(torch.Tensor.cuda) def Tensor_cuda(self, device=None, *args, **kwargs): - if check_device(device): - return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) + if device is None or check_cuda(device): + return self.to(return_xpu(device), *args, **kwargs) else: return original_Tensor_cuda(self, device, *args, **kwargs) original_Tensor_pin_memory = torch.Tensor.pin_memory @wraps(torch.Tensor.pin_memory) def Tensor_pin_memory(self, device=None, *args, **kwargs): - if device is None: - device = "xpu" - if check_device(device): + if device is None or check_cuda(device): return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs) else: return original_Tensor_pin_memory(self, device, *args, **kwargs) @@ -234,23 +272,32 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs): original_UntypedStorage_init = torch.UntypedStorage.__init__ @wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): - if check_device(device): + if check_cuda(device): return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) else: return original_UntypedStorage_init(*args, device=device, **kwargs) -original_UntypedStorage_cuda = torch.UntypedStorage.cuda -@wraps(torch.UntypedStorage.cuda) -def UntypedStorage_cuda(self, device=None, *args, **kwargs): - if check_device(device): - return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) - else: - return original_UntypedStorage_cuda(self, device, *args, **kwargs) +if torch_version >= 2.4: + original_UntypedStorage_to = torch.UntypedStorage.to + @wraps(torch.UntypedStorage.to) + def UntypedStorage_to(self, *args, device=None, **kwargs): + if check_cuda(device): + return original_UntypedStorage_to(self, *args, device=return_xpu(device), **kwargs) + else: + return original_UntypedStorage_to(self, *args, device=device, **kwargs) + + original_UntypedStorage_cuda = torch.UntypedStorage.cuda + @wraps(torch.UntypedStorage.cuda) + def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs): + if device is None or check_cuda(device): + return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs) + else: + return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs) original_torch_empty = torch.empty @wraps(torch.empty) def torch_empty(*args, device=None, **kwargs): - if check_device(device): + if check_cuda(device): return original_torch_empty(*args, device=return_xpu(device), **kwargs) else: return original_torch_empty(*args, device=device, **kwargs) @@ -260,7 +307,7 @@ original_torch_randn = torch.randn def torch_randn(*args, device=None, dtype=None, **kwargs): if dtype is bytes: dtype = None - if check_device(device): + if check_cuda(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) else: return original_torch_randn(*args, device=device, **kwargs) @@ -268,7 +315,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs): original_torch_ones = torch.ones @wraps(torch.ones) def torch_ones(*args, device=None, **kwargs): - if check_device(device): + if check_cuda(device): return original_torch_ones(*args, device=return_xpu(device), **kwargs) else: return original_torch_ones(*args, device=device, **kwargs) @@ -276,7 +323,7 @@ def torch_ones(*args, device=None, **kwargs): original_torch_zeros = torch.zeros @wraps(torch.zeros) def torch_zeros(*args, device=None, **kwargs): - if check_device(device): + if check_cuda(device): return original_torch_zeros(*args, device=return_xpu(device), **kwargs) else: return original_torch_zeros(*args, device=device, **kwargs) @@ -284,7 +331,7 @@ def torch_zeros(*args, device=None, **kwargs): original_torch_full = torch.full @wraps(torch.full) def torch_full(*args, device=None, **kwargs): - if check_device(device): + if check_cuda(device): return original_torch_full(*args, device=return_xpu(device), **kwargs) else: return original_torch_full(*args, device=device, **kwargs) @@ -292,63 +339,91 @@ def torch_full(*args, device=None, **kwargs): original_torch_linspace = torch.linspace @wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): - if check_device(device): + if check_cuda(device): return original_torch_linspace(*args, device=return_xpu(device), **kwargs) else: return original_torch_linspace(*args, device=device, **kwargs) +original_torch_eye = torch.eye +@wraps(torch.eye) +def torch_eye(*args, device=None, **kwargs): + if check_cuda(device): + return original_torch_eye(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_eye(*args, device=device, **kwargs) + original_torch_load = torch.load @wraps(torch.load) def torch_load(f, map_location=None, *args, **kwargs): - if map_location is None: - map_location = "xpu" - if check_device(map_location): + if map_location is None or check_cuda(map_location): return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs) else: return original_torch_load(f, *args, map_location=map_location, **kwargs) -original_torch_Generator = torch.Generator -@wraps(torch.Generator) -def torch_Generator(device=None): - if check_device(device): - return original_torch_Generator(return_xpu(device)) - else: - return original_torch_Generator(device) - @wraps(torch.cuda.synchronize) def torch_cuda_synchronize(device=None): - if check_device(device): + if check_cuda(device): return torch.xpu.synchronize(return_xpu(device)) else: return torch.xpu.synchronize(device) +@wraps(torch.cuda.device) +def torch_cuda_device(device): + if check_cuda(device): + return torch.xpu.device(return_xpu(device)) + else: + return torch.xpu.device(device) + +@wraps(torch.cuda.set_device) +def torch_cuda_set_device(device): + if check_cuda(device): + torch.xpu.set_device(return_xpu(device)) + else: + torch.xpu.set_device(device) + +# torch.Generator has to be a class for isinstance checks +original_torch_Generator = torch.Generator +class torch_Generator(original_torch_Generator): + def __new__(self, device=None): + # can't hijack __init__ because of C override so use return super().__new__ + if check_cuda(device): + return super().__new__(self, return_xpu(device)) + else: + return super().__new__(self, device) + # Hijack Functions: -def ipex_hijacks(legacy=True): - global device_supports_fp64, can_allocate_plus_4gb - if legacy and float(torch.__version__[:3]) < 2.5: - torch.nn.functional.interpolate = interpolate +def ipex_hijacks(): + global device_supports_fp64 + if torch_version >= 2.4: + torch.UntypedStorage.cuda = UntypedStorage_cuda + torch.UntypedStorage.to = UntypedStorage_to torch.tensor = torch_tensor torch.Tensor.to = Tensor_to torch.Tensor.cuda = Tensor_cuda torch.Tensor.pin_memory = Tensor_pin_memory torch.UntypedStorage.__init__ = UntypedStorage_init - torch.UntypedStorage.cuda = UntypedStorage_cuda torch.empty = torch_empty torch.randn = torch_randn torch.ones = torch_ones torch.zeros = torch_zeros torch.full = torch_full torch.linspace = torch_linspace + torch.eye = torch_eye torch.load = torch_load - torch.Generator = torch_Generator torch.cuda.synchronize = torch_cuda_synchronize + torch.cuda.device = torch_cuda_device + torch.cuda.set_device = torch_cuda_set_device + + torch.Generator = torch_Generator + torch._C.Generator = torch_Generator torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda torch.amp.autocast_mode.autocast.__init__ = autocast_init + torch.nn.functional.interpolate = interpolate torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.layer_norm = functional_layer_norm @@ -364,4 +439,28 @@ def ipex_hijacks(legacy=True): if not device_supports_fp64: torch.from_numpy = from_numpy torch.as_tensor = as_tensor - return device_supports_fp64, can_allocate_plus_4gb + + # AMP: + torch.amp.grad_scaler.GradScaler.__init__ = GradScaler_init + torch.is_autocast_enabled = torch_is_autocast_enabled + torch.get_autocast_gpu_dtype = torch_get_autocast_dtype + torch.get_autocast_dtype = torch_get_autocast_dtype + + if hasattr(torch.xpu, "amp"): + if not hasattr(torch.xpu.amp, "custom_fwd"): + torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd + torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd + if not hasattr(torch.xpu.amp, "GradScaler"): + torch.xpu.amp.GradScaler = torch.amp.grad_scaler.GradScaler + torch.cuda.amp = torch.xpu.amp + else: + if not hasattr(torch.amp, "custom_fwd"): + torch.amp.custom_fwd = torch.cuda.amp.custom_fwd + torch.amp.custom_bwd = torch.cuda.amp.custom_bwd + torch.cuda.amp = torch.amp + + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + + return device_supports_fp64