feat: faster safetensors load and split safetensor utils

This commit is contained in:
Kohya S
2025-09-13 19:51:38 +09:00
parent 419a9c4af4
commit 8783f8aed3
17 changed files with 459 additions and 234 deletions

View File

@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
from library.safetensors_utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
@@ -124,7 +124,7 @@ def load_flow_model(
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers: