load models one by one

This commit is contained in:
Kohya S
2024-07-08 22:04:43 +09:00
parent c9de7c4e9a
commit 3ea4fce5e0
3 changed files with 236 additions and 47 deletions

View File

@@ -20,6 +20,175 @@ from library import sdxl_model_util
# region models
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
    """Load a safetensors checkpoint from *path*.

    Args:
        path: path to the .safetensors file.
        dvc: target device for the loaded tensors.
        disable_mmap: when True, read the whole file into memory and parse it,
            instead of letting safetensors memory-map the file.

    Returns:
        The loaded state dict (tensor name -> tensor).
    """
    if disable_mmap:
        # Read eagerly and close the handle; the original leaked the file object.
        with open(path, "rb") as f:
            return safetensors.torch.load(f.read())
    else:
        try:
            return load_file(path, device=dvc)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; fall back to a CPU load for device specs that
            # safetensors rejects.
            return load_file(path)  # prevent device invalid Error
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
    """Extract the MMDiT weights from the combined checkpoint and build the model.

    Keys under "model.diffusion_model." are moved (popped) out of *state_dict*
    into the model's own state dict, so the caller's dict shrinks as models load.
    """
    prefix = "model.diffusion_model."
    mmdit_sd = {k[len(prefix):]: state_dict.pop(k) for k in list(state_dict.keys()) if k.startswith(prefix)}

    # Build on a meta device first, then materialize weights on the target device.
    logger.info("Building MMDit")
    with init_empty_weights():
        mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)

    logger.info("Loading state dict...")
    info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
    logger.info(f"Loaded MMDiT: {info}")
    return mmdit
def load_clip_l(
    state_dict: Dict,
    clip_l_path: Optional[str],
    attn_mode: str,
    clip_dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
):
    """Build the CLIP-L text encoder.

    Weights come from the standalone file *clip_l_path* when given; otherwise
    they are popped out of the combined checkpoint *state_dict*. Returns None
    when no clip_l weights are found in either place.
    """
    clip_l_sd = None
    if clip_l_path:
        logger.info(f"Loading clip_l from {clip_l_path}...")
        raw = load_safetensors(clip_l_path, device, disable_mmap)
        # Standalone files lack the "transformer." prefix the model expects.
        clip_l_sd = {"transformer." + name: raw.pop(name) for name in list(raw.keys())}
    elif "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
        # found clip_l: remove prefix "text_encoders.clip_l."
        logger.info("clip_l is included in the checkpoint")
        prefix = "text_encoders.clip_l."
        clip_l_sd = {name[len(prefix):]: state_dict.pop(name) for name in list(state_dict.keys()) if name.startswith(prefix)}

    if clip_l_sd is None:
        return None

    logger.info("Building ClipL")
    clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
    logger.info("Loading state dict...")
    info = clip_l.load_state_dict(clip_l_sd)
    logger.info(f"Loaded ClipL: {info}")
    clip_l.set_attn_mode(attn_mode)
    return clip_l
def load_clip_g(
    state_dict: Dict,
    clip_g_path: Optional[str],
    attn_mode: str,
    clip_dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
):
    """Build the CLIP-G text encoder.

    Weights come from the standalone file *clip_g_path* when given; otherwise
    they are popped out of the combined checkpoint *state_dict*. Returns None
    when no clip_g weights are found in either place.
    """
    clip_g_sd = None
    if clip_g_path:
        logger.info(f"Loading clip_g from {clip_g_path}...")
        raw = load_safetensors(clip_g_path, device, disable_mmap)
        # Standalone files lack the "transformer." prefix the model expects.
        clip_g_sd = {"transformer." + name: raw.pop(name) for name in list(raw.keys())}
    elif "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
        # found clip_g: remove prefix "text_encoders.clip_g."
        logger.info("clip_g is included in the checkpoint")
        prefix = "text_encoders.clip_g."
        clip_g_sd = {name[len(prefix):]: state_dict.pop(name) for name in list(state_dict.keys()) if name.startswith(prefix)}

    if clip_g_sd is None:
        return None

    logger.info("Building ClipG")
    clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
    logger.info("Loading state dict...")
    info = clip_g.load_state_dict(clip_g_sd)
    logger.info(f"Loaded ClipG: {info}")
    clip_g.set_attn_mode(attn_mode)
    return clip_g
def load_t5xxl(
    state_dict: Dict,
    t5xxl_path: Optional[str],
    attn_mode: str,
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
):
    """Build the T5-XXL text encoder.

    Weights come from the standalone file *t5xxl_path* when given; otherwise
    they are popped out of the combined checkpoint *state_dict*. Returns None
    when no t5xxl weights are found in either place.
    """
    t5xxl_sd = None
    if t5xxl_path:
        logger.info(f"Loading t5xxl from {t5xxl_path}...")
        raw = load_safetensors(t5xxl_path, device, disable_mmap)
        # Standalone files lack the "transformer." prefix the model expects.
        t5xxl_sd = {"transformer." + name: raw.pop(name) for name in list(raw.keys())}
    elif "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
        # found t5xxl: remove prefix "text_encoders.t5xxl."
        logger.info("t5xxl is included in the checkpoint")
        prefix = "text_encoders.t5xxl."
        t5xxl_sd = {name[len(prefix):]: state_dict.pop(name) for name in list(state_dict.keys()) if name.startswith(prefix)}

    if t5xxl_sd is None:
        return None

    logger.info("Building T5XXL")
    # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
    t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
    t5xxl.to(dtype=dtype)
    logger.info("Loading state dict...")
    info = t5xxl.load_state_dict(t5xxl_sd)
    logger.info(f"Loaded T5XXL: {info}")
    t5xxl.set_attn_mode(attn_mode)
    return t5xxl
def load_vae(
    state_dict: Dict,
    vae_path: Optional[str],
    vae_dtype: Optional[Union[str, torch.dtype]],
    device: Optional[Union[str, torch.device]],
    disable_mmap: bool = False,
):
    """Build the SD VAE.

    Weights come from the standalone file *vae_path* when given; otherwise
    the "first_stage_model." entries are popped out of the combined
    checkpoint *state_dict*. The model is moved to *device*/*vae_dtype*.
    """
    if vae_path:
        logger.info(f"Loading VAE from {vae_path}...")
        vae_sd = load_safetensors(vae_path, device, disable_mmap)
    else:
        # remove prefix "first_stage_model."
        vae_prefix = "first_stage_model."
        vae_sd = {k[len(vae_prefix):]: state_dict.pop(k) for k in list(state_dict.keys()) if k.startswith(vae_prefix)}

    logger.info("Building VAE")
    vae = sd3_models.SDVAE()
    logger.info("Loading state dict...")
    info = vae.load_state_dict(vae_sd)
    logger.info(f"Loaded VAE: {info}")
    vae.to(device=device, dtype=vae_dtype)
    return vae
def load_models(
ckpt_path: str,
clip_l_path: str,