add disable_mmap to args

This commit is contained in:
Zovjsra
2024-04-16 16:40:08 +08:00
parent 71e2c91330
commit 64916a35b2
2 changed files with 16 additions and 7 deletions

View File

@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors
)
# work on low-ram device
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
def _load_target_model(
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -75,7 +76,7 @@ def _load_target_model(
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
@@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):