mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactor SD3 CLIP to transformers etc.
This commit is contained in:
@@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
|
||||
# region Logging
|
||||
|
||||
|
||||
def add_logging_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--console_log_level",
|
||||
@@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region PyTorch utils
|
||||
|
||||
|
||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||
"""
|
||||
Convert a string to a torch.dtype
|
||||
@@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen:
|
||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
|
||||
return state_dict
|
||||
else:
|
||||
try:
|
||||
state_dict = load_file(path, device=device)
|
||||
except:
|
||||
state_dict = load_file(path) # prevent device invalid Error
|
||||
if dtype is not None:
|
||||
for key in state_dict.keys():
|
||||
state_dict[key] = state_dict[key].to(dtype=dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image utils
|
||||
|
||||
|
||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||
|
||||
@@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
||||
# region Gradual Latent hires fix
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user