Kohya-ss-sd-scripts/library/device_utils.py
import functools
import gc
from typing import Optional, Union

import torch

try:
    # Intel GPU support for PyTorch older than 2.5
    # ipex is not needed after PyTorch 2.5
    import intel_extension_for_pytorch as ipex  # noqa
except Exception:
    pass

try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    HAS_CUDA = False

try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    HAS_MPS = False

try:
    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False

def clean_memory():
    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()
    if HAS_XPU:
        torch.xpu.empty_cache()
    if HAS_MPS:
        torch.mps.empty_cache()
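
# clean_memory() empties the cache on every backend that was detected at import time,
# regardless of which device the caller actually uses; clean_memory_on_device() below
# is the device-targeted variant.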

def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
    r"""
    Clean memory on the specified device. Called from training scripts.
    """
    gc.collect()

    if device is None:
        return

    if isinstance(device, str):
        device = torch.device(device)

    # device may be "cuda" or "cuda:0", so check device.type rather than comparing the full name
    if device.type == "cuda":
        torch.cuda.empty_cache()
    if device.type == "xpu":
        torch.xpu.empty_cache()
    if device.type == "mps":
        torch.mps.empty_cache()
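
# Illustrative usage (sketch; `accelerator` is an assumed Hugging Face Accelerate object,
# not defined in this module):
#
#     clean_memory_on_device(accelerator.device)  # e.g. between training phases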

def synchronize_device(device: Optional[Union[str, torch.device]]):
    if device is None:
        return

    if isinstance(device, str):
        device = torch.device(device)

    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()
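
# synchronize_device() blocks until all queued work on the given device has finished;
# useful before wall-clock timing or before reading results that were computed asynchronously.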

@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Do not call this function from training scripts. Use accelerator.device instead.
    """
    if HAS_CUDA:
        device = torch.device("cuda")
    elif HAS_XPU:
        device = torch.device("xpu")
    elif HAS_MPS:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"get_preferred_device() -> {device}")
    return device
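
# Illustrative usage (sketch): utility scripts running outside an Accelerate context can do
#
#     device = get_preferred_device()
#     model.to(device)  # `model` is an assumed nn.Module, not defined here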

def _normalize_cuda_arch(arch) -> Optional[str]:
    if isinstance(arch, str):
        return arch if arch.startswith("sm_") else None
    if isinstance(arch, (tuple, list)) and len(arch) >= 2:
        return f"sm_{int(arch[0])}{int(arch[1])}"
    return None
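
# Examples of the mapping above: "sm_86" -> "sm_86", (8, 6) -> "sm_86", and any string
# not starting with "sm_" -> None. Arch strings from torch.cuda.get_arch_list() and
# capability tuples from torch.cuda.get_device_capability() are thus brought to a
# common "sm_XY" form before comparison.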

def validate_cuda_device_compatibility(device: Optional[Union[str, torch.device]] = None):
    if not HAS_CUDA:
        return

    if device is None:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)

    if device.type != "cuda":
        return

    get_arch_list = getattr(torch.cuda, "get_arch_list", None)
    if get_arch_list is None:
        return

    try:
        supported_arches = sorted(
            {arch_name for arch_name in (_normalize_cuda_arch(arch) for arch in get_arch_list()) if arch_name is not None}
        )
        device_arch = _normalize_cuda_arch(torch.cuda.get_device_capability(device))
        device_name = torch.cuda.get_device_name(device)
    except Exception:
        return

    if supported_arches and device_arch is not None and device_arch not in supported_arches:
        cuda_version = getattr(torch.version, "cuda", None)
        cuda_suffix = f" with CUDA {cuda_version}" if cuda_version else ""
        supported = ", ".join(supported_arches)
        raise RuntimeError(
            f"CUDA device '{device_name}' reports {device_arch}, but this PyTorch build{cuda_suffix} only supports {supported}. "
            + "Install a PyTorch build that includes kernels for this GPU from https://pytorch.org/get-started/locally/ or build PyTorch from source."
        )
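
# Illustrative usage (sketch): calling this early in a training script fails fast with a
# clear message when the installed PyTorch wheel lacks kernels for the local GPU, e.g.
#
#     validate_cuda_device_compatibility("cuda:0")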

def init_ipex():
    """
    Apply the IPEX hijacks for the CUDA API using `library.ipex.ipex_init`.
    This function should run right after importing torch and before doing anything else.
    If XPU is not available, this function does nothing.
    """
    try:
        if HAS_XPU:
            from library.ipex import ipex_init

            is_initialized, error_message = ipex_init()
            if not is_initialized:
                print("failed to initialize ipex:", error_message)
        else:
            return
    except Exception as e:
        print("failed to initialize ipex:", e)