mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Modifying the method for get the Torch version
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from packaging import version
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
has_ipex = True
|
||||
@@ -8,7 +9,7 @@ except Exception:
|
||||
has_ipex = False
|
||||
from .hijacks import ipex_hijacks
|
||||
|
||||
torch_version = float(torch.__version__[:3])
|
||||
torch_version = version.parse(torch.__version__)
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
@@ -71,7 +72,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
if torch_version < 2.3:
|
||||
if torch_version < version.parse("2.3"):
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
@@ -114,14 +115,14 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.threading = torch.xpu.threading
|
||||
torch.cuda.traceback = torch.xpu.traceback
|
||||
|
||||
if torch_version < 2.5:
|
||||
if torch_version < version.parse("2.5"):
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
|
||||
if torch_version < 2.7:
|
||||
if torch_version < version.parse("2.7"):
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.List = torch.xpu.List
|
||||
|
||||
@@ -160,7 +161,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# C
|
||||
if torch_version < 2.3:
|
||||
if torch_version < version.parse("2.3"):
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 12
|
||||
|
||||
Reference in New Issue
Block a user