Fix IPEX support and add XPU device to device_utils

This commit is contained in:
Disty0
2024-01-31 17:32:37 +03:00
parent 2ca4d0c831
commit a6a2b5a867
27 changed files with 248 additions and 245 deletions

View File

@@ -30,7 +30,11 @@ from io import BytesIO
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
@@ -66,7 +70,6 @@ import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.device_utils import clean_memory
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う