mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Cleanup IPEX libs
This commit is contained in:
@@ -4,13 +4,12 @@ import contextlib
|
|||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
from .hijacks import ipex_hijacks
|
from .hijacks import ipex_hijacks
|
||||||
from .attention import attention_init
|
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
def ipex_init(): # pylint: disable=too-many-statements
|
def ipex_init(): # pylint: disable=too-many-statements
|
||||||
try:
|
try:
|
||||||
#Replace cuda with xpu:
|
# Replace cuda with xpu:
|
||||||
torch.cuda.current_device = torch.xpu.current_device
|
torch.cuda.current_device = torch.xpu.current_device
|
||||||
torch.cuda.current_stream = torch.xpu.current_stream
|
torch.cuda.current_stream = torch.xpu.current_stream
|
||||||
torch.cuda.device = torch.xpu.device
|
torch.cuda.device = torch.xpu.device
|
||||||
@@ -91,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||||
torch.cuda.__file__ = torch.xpu.__file__
|
torch.cuda.__file__ = torch.xpu.__file__
|
||||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||||
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||||
|
|
||||||
#Memory:
|
# Memory:
|
||||||
torch.cuda.memory = torch.xpu.memory
|
torch.cuda.memory = torch.xpu.memory
|
||||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||||
torch.xpu.empty_cache = lambda: None
|
torch.xpu.empty_cache = lambda: None
|
||||||
@@ -113,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
||||||
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
||||||
|
|
||||||
#RNG:
|
# RNG:
|
||||||
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
||||||
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
||||||
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
||||||
@@ -124,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.seed_all = torch.xpu.seed_all
|
torch.cuda.seed_all = torch.xpu.seed_all
|
||||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||||
|
|
||||||
#AMP:
|
# AMP:
|
||||||
torch.cuda.amp = torch.xpu.amp
|
torch.cuda.amp = torch.xpu.amp
|
||||||
if not hasattr(torch.cuda.amp, "common"):
|
if not hasattr(torch.cuda.amp, "common"):
|
||||||
torch.cuda.amp.common = contextlib.nullcontext()
|
torch.cuda.amp.common = contextlib.nullcontext()
|
||||||
@@ -139,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||||
|
|
||||||
#C
|
# C
|
||||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||||
ipex._C._DeviceProperties.major = 2023
|
ipex._C._DeviceProperties.major = 2023
|
||||||
ipex._C._DeviceProperties.minor = 2
|
ipex._C._DeviceProperties.minor = 2
|
||||||
|
|
||||||
#Fix functions with ipex:
|
# Fix functions with ipex:
|
||||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||||
torch._utils._get_available_device_type = lambda: "xpu"
|
torch._utils._get_available_device_type = lambda: "xpu"
|
||||||
torch.has_cuda = True
|
torch.has_cuda = True
|
||||||
@@ -166,7 +165,11 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
|
|
||||||
ipex_hijacks()
|
ipex_hijacks()
|
||||||
if not torch.xpu.has_fp64_dtype():
|
if not torch.xpu.has_fp64_dtype():
|
||||||
attention_init()
|
try:
|
||||||
|
from .attention import attention_init
|
||||||
|
attention_init()
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
from .diffusers import ipex_diffusers
|
from .diffusers import ipex_diffusers
|
||||||
ipex_diffusers()
|
ipex_diffusers()
|
||||||
|
|||||||
Reference in New Issue
Block a user