mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Deduplicate ipex initialization code
This commit is contained in:
@@ -1,11 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
|
||||||
|
|||||||
@@ -11,15 +11,10 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
|
|||||||
@@ -66,15 +66,10 @@ import diffusers
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
|
|||||||
24
library/ipex_interop.py
Normal file
24
library/ipex_interop.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def init_ipex():
|
||||||
|
"""
|
||||||
|
Try to import `intel_extension_for_pytorch`, and apply
|
||||||
|
the hijacks using `library.ipex.ipex_init`.
|
||||||
|
|
||||||
|
If IPEX is not installed, this function does nothing.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import intel_extension_for_pytorch as ipex # noqa
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
|
if torch.xpu.is_available():
|
||||||
|
is_initialized, error_message = ipex_init()
|
||||||
|
if not is_initialized:
|
||||||
|
print("failed to initialize ipex:", error_message)
|
||||||
|
except Exception as e:
|
||||||
|
print("failed to initialize ipex:", e)
|
||||||
@@ -5,15 +5,9 @@ import math
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
import diffusers
|
import diffusers
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||||
|
|||||||
@@ -18,15 +18,10 @@ import diffusers
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
|
|||||||
@@ -9,13 +9,11 @@ import random
|
|||||||
from einops import repeat
|
from einops import repeat
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
from library.ipex_interop import init_ipex
|
||||||
if torch.xpu.is_available():
|
|
||||||
from library.ipex import ipex_init
|
init_ipex()
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from diffusers import EulerDiscreteScheduler
|
from diffusers import EulerDiscreteScheduler
|
||||||
|
|||||||
@@ -11,15 +11,10 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import sdxl_model_util
|
from library import sdxl_model_util
|
||||||
|
|||||||
@@ -14,13 +14,11 @@ import toml
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
from library.ipex_interop import init_ipex
|
||||||
if torch.xpu.is_available():
|
|
||||||
from library.ipex import ipex_init
|
init_ipex()
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import accelerate
|
import accelerate
|
||||||
|
|||||||
@@ -11,13 +11,11 @@ import toml
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
from library.ipex_interop import init_ipex
|
||||||
if torch.xpu.is_available():
|
|
||||||
from library.ipex import ipex_init
|
init_ipex()
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler, ControlNetModel
|
||||||
|
|||||||
@@ -1,15 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||||
import train_network
|
import train_network
|
||||||
|
|
||||||
|
|||||||
@@ -3,13 +3,9 @@ import os
|
|||||||
|
|
||||||
import regex
|
import regex
|
||||||
import torch
|
import torch
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
import open_clip
|
import open_clip
|
||||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||||
|
|
||||||
|
|||||||
@@ -12,15 +12,10 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler, ControlNetModel
|
||||||
|
|||||||
@@ -12,15 +12,10 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
|
|||||||
@@ -14,15 +14,10 @@ from tqdm import tqdm
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import model_util
|
from library import model_util
|
||||||
|
|||||||
@@ -8,15 +8,10 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from library.ipex_interop import init_ipex
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
init_ipex()
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|||||||
@@ -8,13 +8,11 @@ from multiprocessing import Value
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
from library.ipex_interop import init_ipex
|
||||||
if torch.xpu.is_available():
|
|
||||||
from library.ipex import ipex_init
|
init_ipex()
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|||||||
Reference in New Issue
Block a user