mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Refactor device determination to function; add MPS fallback
This commit is contained in:
@@ -15,8 +15,9 @@ from torchvision.transforms.functional import InterpolationMode
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from blip.blip import blip_decoder, is_url
|
||||
import library.train_util as train_util
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
DEVICE = get_preferred_device()
|
||||
|
||||
|
||||
IMAGE_SIZE = 384
|
||||
|
||||
@@ -10,9 +10,9 @@ from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
from transformers.generation.utils import GenerationMixin
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
DEVICE = get_preferred_device()
|
||||
|
||||
PATTERN_REPLACE = [
|
||||
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
|
||||
|
||||
@@ -14,7 +14,9 @@ from torchvision import transforms
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
|
||||
@@ -66,7 +66,7 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import clean_memory
|
||||
from library.device_utils import clean_memory, get_preferred_device
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
@@ -2324,7 +2324,7 @@ def main(args):
|
||||
scheduler.config.clip_sample = True
|
||||
|
||||
# deviceを決定する
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
||||
device = get_preferred_device()
|
||||
|
||||
# custom pipelineをコピったやつを生成する
|
||||
if args.vae_slices:
|
||||
|
||||
@@ -1,9 +1,34 @@
|
||||
import functools
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
except Exception:
|
||||
HAS_CUDA = False
|
||||
|
||||
try:
|
||||
HAS_MPS = torch.backends.mps.is_available()
|
||||
except Exception:
|
||||
HAS_MPS = False
|
||||
|
||||
|
||||
def clean_memory():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if HAS_CUDA:
|
||||
torch.cuda.empty_cache()
|
||||
if HAS_MPS:
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_preferred_device() -> torch.device:
|
||||
if HAS_CUDA:
|
||||
device = torch.device("cuda")
|
||||
elif HAS_MPS:
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"get_preferred_device() -> {device}")
|
||||
return device
|
||||
|
||||
@@ -11,6 +11,8 @@ from tqdm import tqdm
|
||||
from transformers import CLIPTextModel
|
||||
import torch
|
||||
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
|
||||
def make_unet_conversion_map() -> Dict[str, str]:
|
||||
unet_conversion_map_layer = []
|
||||
@@ -476,7 +478,7 @@ if __name__ == "__main__":
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
device = get_preferred_device()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
|
||||
|
||||
@@ -9,11 +9,12 @@ import torch
|
||||
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = get_preferred_device()
|
||||
|
||||
|
||||
def interrogate(args):
|
||||
|
||||
@@ -18,7 +18,7 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import clean_memory
|
||||
from library.device_utils import clean_memory, get_preferred_device
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
@@ -1495,7 +1495,7 @@ def main(args):
|
||||
# scheduler.config.clip_sample = True
|
||||
|
||||
# deviceを決定する
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
||||
device = get_preferred_device()
|
||||
|
||||
# custom pipelineをコピったやつを生成する
|
||||
if args.vae_slices:
|
||||
|
||||
@@ -10,6 +10,7 @@ from einops import repeat
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.device_utils import get_preferred_device
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
@@ -85,7 +86,7 @@ if __name__ == "__main__":
|
||||
guidance_scale = 7
|
||||
seed = None # 1
|
||||
|
||||
DEVICE = "cuda"
|
||||
DEVICE = get_preferred_device()
|
||||
DTYPE = torch.float16 # bfloat16 may work
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -15,6 +15,8 @@ from torch import nn
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.device_utils import get_preferred_device
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||
@@ -255,7 +257,7 @@ def create_upscaler(**kwargs):
|
||||
|
||||
# another interface: upscale images with a model for given images from command line
|
||||
def upscale_images(args: argparse.Namespace):
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
DEVICE = get_preferred_device()
|
||||
us_dtype = torch.float16 # TODO: support fp32/bf16
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user