Update README and help message, etc.

This commit is contained in:
Kohya S
2024-05-12 17:55:08 +09:00
parent 8d1b1acd33
commit 9ddb4d7a01
3 changed files with 17 additions and 3 deletions

View File

@@ -9,8 +9,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionMode
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 0.13025
@@ -171,8 +173,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
if(disable_mmap):
state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read())
if disable_mmap:
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
else:
try:
state_dict = load_file(ckpt_path, device=map_location)

View File

@@ -5,6 +5,7 @@ from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
@@ -13,8 +14,10 @@ from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -44,7 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors
args.disable_mmap_load_safetensors,
)
# work on low-ram device
@@ -336,6 +339,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)