Merge pull request #1266 from Zovjsra/feature/disable-mmap

Add "--disable_mmap_load_safetensors" parameter
This commit is contained in:
Kohya S
2024-05-12 17:43:44 +09:00
committed by GitHub
2 changed files with 16 additions and 7 deletions

View File

@@ -1,4 +1,5 @@
import torch
import safetensors
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
@@ -163,17 +164,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
# model_version is reserved for future use
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
if(disable_mmap):
state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read())
else:
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path)  # prevent device invalid Error
epoch = None
global_step = None
else:

View File

@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors
)
# work on low-ram device # work on low-ram device
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
def _load_target_model(
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -75,7 +76,7 @@ def _load_target_model(
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
@@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):