refactor SD3 CLIP to transformers etc.

This commit is contained in:
Kohya S
2024-10-24 19:49:28 +09:00
parent 138dac4aea
commit 623017f716
13 changed files with 1201 additions and 2150 deletions

View File

@@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
    """Start ``f(*args, **kwargs)`` on a new thread and return immediately.

    The thread object is not returned, so the caller cannot join it or read
    the callable's result through this function.
    """
    worker = threading.Thread(target=f, args=args, kwargs=kwargs)
    worker.start()
# region Logging
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
@@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
# endregion
# region PyTorch utils
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
@@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
    path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
    """
    Load all tensors of a safetensors file into a state dict.

    Args:
        path: path of the safetensors file to read.
        device: device handed to the loader (mmap path) or to ``Tensor.to`` (non-mmap path).
        disable_mmap: if True, read the file through ``MemoryEfficientSafeOpen``
            (experimental loader) instead of ``load_file``.
        dtype: dtype each tensor is converted to. With ``None`` no dtype conversion is
            requested and tensors keep the dtype they were loaded with.

    Returns:
        A dict mapping each key of the file to its tensor.

    Note:
        On the ``load_file`` path, if loading with ``device`` raises, the file is loaded
        again without a device argument; in that case tensors are NOT moved to ``device``
        afterwards (they stay wherever ``load_file`` puts them by default).
    """
    if disable_mmap:
        # Experimental loader that avoids mmap: tensors are read one at a time and
        # moved/cast in a single ``to`` call.
        with MemoryEfficientSafeOpen(path) as f:
            return {key: f.get_tensor(key).to(device, dtype=dtype) for key in f.keys()}

    try:
        state_dict = load_file(path, device=device)
    except Exception:
        # Prevent "device invalid" errors: retry without a device argument.
        # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt / SystemExit
        # still propagate instead of triggering a second full load.
        state_dict = load_file(path)

    if dtype is not None:
        state_dict = {key: tensor.to(dtype=dtype) for key, tensor in state_dict.items()}
    return state_dict
# endregion
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
@@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
# endregion
# TODO make inf_utils.py
# region Gradual Latent hires fix