Merge pull request #1054 from akx/mps

Device support improvements (MPS)
This commit is contained in:
Kohya S
2024-01-31 21:30:12 +09:00
committed by GitHub
22 changed files with 91 additions and 75 deletions

View File

@@ -11,6 +11,8 @@ from tqdm import tqdm
from transformers import CLIPTextModel
import torch
from library.device_utils import get_preferred_device
def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
@@ -476,7 +478,7 @@ if __name__ == "__main__":
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")

View File

@@ -9,11 +9,12 @@ import torch
import library.model_util as model_util
import lora
from library.device_utils import get_preferred_device
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = get_preferred_device()
def interrogate(args):