Fix IPEX support and add XPU device to device_utils

This commit is contained in:
Disty0
2024-01-31 17:32:37 +03:00
parent 2ca4d0c831
commit a6a2b5a867
27 changed files with 248 additions and 245 deletions

View File

@@ -1,7 +1,7 @@
import torch
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex
init_ipex()
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

View File

@@ -8,11 +8,9 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from accelerate.utils import set_seed

View File

@@ -9,13 +9,16 @@ from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.device_utils import get_preferred_device
DEVICE = get_preferred_device()

View File

@@ -5,12 +5,15 @@ import re
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
from library.device_utils import get_preferred_device
DEVICE = get_preferred_device()

View File

@@ -8,14 +8,16 @@ from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
from library.device_utils import get_preferred_device
DEVICE = get_preferred_device()
IMAGE_TRANSFORMS = transforms.Compose(

View File

@@ -64,11 +64,9 @@ import re
import diffusers
import numpy as np
import torch
from library.device_utils import clean_memory, get_preferred_device
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory, get_preferred_device
init_ipex()
import torchvision

View File

@@ -13,11 +13,19 @@ try:
except Exception:
HAS_MPS = False
try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
def clean_memory():
    """Best-effort release of cached accelerator memory.

    Runs Python garbage collection first, then empties the allocator
    cache of every backend whose availability flag was resolved at
    import time (CUDA, Intel XPU via IPEX, Apple MPS).
    """
    gc.collect()
    # Attribute lookup is deferred via getattr so a missing backend
    # module (e.g. torch.xpu without IPEX) is never touched.
    for available, backend in ((HAS_CUDA, "cuda"), (HAS_XPU, "xpu"), (HAS_MPS, "mps")):
        if available:
            getattr(torch, backend).empty_cache()
@@ -26,9 +34,30 @@ def clean_memory():
def get_preferred_device() -> torch.device:
    """Return the best available compute device.

    Preference order is CUDA, then Intel XPU, then Apple MPS, with CPU
    as the fallback. The selected device is printed for visibility.
    """
    if HAS_CUDA:
        name = "cuda"
    elif HAS_XPU:
        name = "xpu"
    elif HAS_MPS:
        name = "mps"
    else:
        name = "cpu"
    device = torch.device(name)
    print(f"get_preferred_device() -> {device}")
    return device
def init_ipex():
    """
    Apply the IPEX CUDA hijacks using `library.ipex.ipex_init`.

    This function should run right after importing torch and before doing
    anything else. If no Intel XPU was detected at import time, or IPEX is
    not available, this function does nothing; any initialization failure
    is printed rather than raised, so non-Intel setups are unaffected.
    """
    try:
        # HAS_XPU is resolved once at module import; without an XPU there
        # is nothing to hijack. (The redundant `else: return` of the
        # original has been dropped — the function falls through anyway.)
        if HAS_XPU:
            from library.ipex import ipex_init

            is_initialized, error_message = ipex_init()
            if not is_initialized:
                print("failed to initialize ipex:", error_message)
    except Exception as e:
        # Best-effort by design: never let IPEX setup crash the caller.
        print("failed to initialize ipex:", e)

View File

@@ -9,6 +9,9 @@ from .hijacks import ipex_hijacks
def ipex_init(): # pylint: disable=too-many-statements
try:
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
@@ -170,6 +173,7 @@ def ipex_init(): # pylint: disable=too-many-statements
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e
return True, None

View File

@@ -1,24 +0,0 @@
import torch
def init_ipex():
    """Apply the `library.ipex` hijacks when IPEX is installed.

    Returns silently when `intel_extension_for_pytorch` cannot be
    imported or no XPU device is available; any other failure during
    initialization is printed instead of raised.
    """
    try:
        import intel_extension_for_pytorch as ipex  # noqa
    except ImportError:
        # IPEX not installed: nothing to do.
        return
    try:
        from library.ipex import ipex_init

        if torch.xpu.is_available():
            ok, err = ipex_init()
            if not ok:
                print("failed to initialize ipex:", err)
    except Exception as e:
        print("failed to initialize ipex:", e)

View File

@@ -3,11 +3,11 @@
import math
import os
import torch
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel

View File

@@ -2,12 +2,15 @@ import argparse
import math
import os
from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.device_utils import clean_memory
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"

View File

@@ -30,7 +30,11 @@ from io import BytesIO
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
@@ -66,7 +70,6 @@ import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.device_utils import clean_memory
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う

View File

@@ -9,9 +9,10 @@ from diffusers import UNet2DConditionModel
import numpy as np
from tqdm import tqdm
from transformers import CLIPTextModel
import torch
from library.device_utils import get_preferred_device
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
def make_unet_conversion_map() -> Dict[str, str]:

View File

@@ -5,11 +5,13 @@ from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
import library.model_util as model_util
import lora
from library.device_utils import get_preferred_device
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う

View File

@@ -16,11 +16,9 @@ import re
import diffusers
import numpy as np
import torch
from library.device_utils import clean_memory, get_preferred_device
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory, get_preferred_device
init_ipex()
import torchvision

View File

@@ -8,11 +8,9 @@ import os
import random
from einops import repeat
import numpy as np
import torch
from library.device_utils import get_preferred_device
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm

View File

@@ -8,11 +8,9 @@ from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from accelerate.utils import set_seed

View File

@@ -12,11 +12,9 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP

View File

@@ -9,11 +9,9 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP

View File

@@ -1,9 +1,7 @@
import argparse
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util

View File

@@ -2,10 +2,11 @@ import argparse
import os
import regex
import torch
from library.ipex_interop import init_ipex
import torch
from library.device_utils import init_ipex
init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

View File

@@ -11,12 +11,13 @@ from typing import Dict, List
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torch import nn
from tqdm import tqdm
from PIL import Image
from library.device_utils import get_preferred_device
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):

View File

@@ -9,11 +9,9 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP

View File

@@ -9,11 +9,9 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from accelerate.utils import set_seed

View File

@@ -10,14 +10,13 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util

View File

@@ -5,11 +5,9 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from accelerate.utils import set_seed

View File

@@ -6,11 +6,9 @@ import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from library.device_utils import clean_memory
from library.ipex_interop import init_ipex
from library.device_utils import init_ipex, clean_memory
init_ipex()
from accelerate.utils import set_seed