mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #1266 from Zovjsra/feature/disable-mmap
Add "--disable_mmap_load_safetensors" parameter
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import safetensors
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
@@ -163,17 +164,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
|
|||||||
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||||
|
|
||||||
|
|
||||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
|
||||||
# model_version is reserved for future use
|
# model_version is reserved for future use
|
||||||
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
|
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
|
||||||
|
|
||||||
# Load the state dict
|
# Load the state dict
|
||||||
if model_util.is_safetensors(ckpt_path):
|
if model_util.is_safetensors(ckpt_path):
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
try:
|
if(disable_mmap):
|
||||||
state_dict = load_file(ckpt_path, device=map_location)
|
state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read())
|
||||||
except:
|
else:
|
||||||
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
try:
|
||||||
|
state_dict = load_file(ckpt_path, device=map_location)
|
||||||
|
except:
|
||||||
|
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
||||||
epoch = None
|
epoch = None
|
||||||
global_step = None
|
global_step = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
weight_dtype,
|
weight_dtype,
|
||||||
accelerator.device if args.lowram else "cpu",
|
accelerator.device if args.lowram else "cpu",
|
||||||
model_dtype,
|
model_dtype,
|
||||||
|
args.disable_mmap_load_safetensors
|
||||||
)
|
)
|
||||||
|
|
||||||
# work on low-ram device
|
# work on low-ram device
|
||||||
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
|
|
||||||
|
|
||||||
def _load_target_model(
|
def _load_target_model(
|
||||||
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
|
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
|
||||||
):
|
):
|
||||||
# model_dtype only work with full fp16/bf16
|
# model_dtype only work with full fp16/bf16
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
@@ -75,7 +76,7 @@ def _load_target_model(
|
|||||||
unet,
|
unet,
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
|
||||||
else:
|
else:
|
||||||
# Diffusers model is loaded to CPU
|
# Diffusers model is loaded to CPU
|
||||||
from diffusers import StableDiffusionXLPipeline
|
from diffusers import StableDiffusionXLPipeline
|
||||||
@@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable_mmap_load_safetensors",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||||
|
|||||||
Reference in New Issue
Block a user