mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
IPEX fix pin_memory
This commit is contained in:
@@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||||
torch.Tensor.cuda = torch.Tensor.xpu
|
torch.Tensor.cuda = torch.Tensor.xpu
|
||||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||||
|
torch.nn.Module.cuda = torch.nn.Module.xpu
|
||||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||||
@@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
|
|
||||||
# C
|
# C
|
||||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||||
ipex._C._DeviceProperties.major = 2023
|
ipex._C._DeviceProperties.major = 2024
|
||||||
ipex._C._DeviceProperties.minor = 2
|
ipex._C._DeviceProperties.minor = 0
|
||||||
|
|
||||||
# Fix functions with ipex:
|
# Fix functions with ipex:
|
||||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||||
|
|||||||
@@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
|
|||||||
else:
|
else:
|
||||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
||||||
|
@wraps(torch.Tensor.pin_memory)
|
||||||
|
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
||||||
|
if device is None:
|
||||||
|
device = "xpu"
|
||||||
|
if check_device(device):
|
||||||
|
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
||||||
|
|
||||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||||
@wraps(torch.UntypedStorage.__init__)
|
@wraps(torch.UntypedStorage.__init__)
|
||||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||||
@@ -259,10 +269,12 @@ def torch_Generator(device=None):
|
|||||||
original_torch_load = torch.load
|
original_torch_load = torch.load
|
||||||
@wraps(torch.load)
|
@wraps(torch.load)
|
||||||
def torch_load(f, map_location=None, *args, **kwargs):
|
def torch_load(f, map_location=None, *args, **kwargs):
|
||||||
|
if map_location is None:
|
||||||
|
map_location = "xpu"
|
||||||
if check_device(map_location):
|
if check_device(map_location):
|
||||||
return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
|
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
||||||
else:
|
else:
|
||||||
return original_torch_load(f, map_location=map_location, *args, **kwargs)
|
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Hijack Functions:
|
# Hijack Functions:
|
||||||
@@ -270,6 +282,7 @@ def ipex_hijacks():
|
|||||||
torch.tensor = torch_tensor
|
torch.tensor = torch_tensor
|
||||||
torch.Tensor.to = Tensor_to
|
torch.Tensor.to = Tensor_to
|
||||||
torch.Tensor.cuda = Tensor_cuda
|
torch.Tensor.cuda = Tensor_cuda
|
||||||
|
torch.Tensor.pin_memory = Tensor_pin_memory
|
||||||
torch.UntypedStorage.__init__ = UntypedStorage_init
|
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||||
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||||
torch.empty = torch_empty
|
torch.empty = torch_empty
|
||||||
|
|||||||
Reference in New Issue
Block a user