mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
load models one by one
This commit is contained in:
@@ -20,6 +20,175 @@ from library import sdxl_model_util
|
||||
# region models
|
||||
|
||||
|
||||
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
    """Load a safetensors file from ``path`` and return the loaded state dict.

    Args:
        path: Path of the safetensors file to read.
        dvc: Device handed to ``load_file`` on the default (non ``disable_mmap``)
            path. It is not used when ``disable_mmap`` is True.
        disable_mmap: When True, the whole file is read into memory and
            deserialized with ``safetensors.torch.load`` instead of going
            through ``load_file``.

    Returns:
        The object returned by ``safetensors.torch.load`` or ``load_file``.
    """
    if disable_mmap:
        # Use a context manager so the file handle is closed deterministically
        # rather than being left for the garbage collector.
        with open(path, "rb") as f:
            return safetensors.torch.load(f.read())

    try:
        return load_file(path, device=dvc)
    except Exception:
        # Deliberate best-effort fallback: retry without an explicit device
        # ("prevent device invalid Error"). Only ``Exception`` is caught so that
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        return load_file(path)
|
||||
|
||||
|
||||
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
    """Build the MMDiT model and load its weights out of ``state_dict``.

    Every entry whose key starts with ``"model.diffusion_model."`` is popped
    from ``state_dict`` (the caller's dict is mutated) and stored under the
    key with that prefix removed.

    Args:
        state_dict: Checkpoint state dict; matching entries are removed from it.
        attn_mode: Passed to ``sd3_models.create_mmdit_sd3_medium_configs``.
        dtype: Passed to ``sdxl_model_util._load_state_dict_on_device``.
        device: Passed to ``sdxl_model_util._load_state_dict_on_device``.

    Returns:
        The MMDiT model created by ``sd3_models.create_mmdit_sd3_medium_configs``.
    """
    prefix = "model.diffusion_model."
    prefix_len = len(prefix)

    # Snapshot the keys first: matching entries are removed while we collect them.
    weights = {
        name[prefix_len:]: state_dict.pop(name)
        for name in list(state_dict.keys())
        if name.startswith(prefix)
    }

    # load MMDiT
    logger.info("Building MMDit")
    with init_empty_weights():
        model = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)

    logger.info("Loading state dict...")
    load_info = sdxl_model_util._load_state_dict_on_device(model, weights, device, dtype)
    logger.info(f"Loaded MMDiT: {load_info}")
    return model
|
||||
|
||||
|
||||
def load_clip_l(
    state_dict: Dict,
    clip_l_path: Optional[str],
    attn_mode: str,
    clip_dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
):
    """Load the CLIP-L text encoder from a separate file or from the checkpoint.

    If ``clip_l_path`` is truthy, the weights are read with ``load_safetensors``
    and every key is re-registered under a ``"transformer."`` prefix. Otherwise,
    if the checkpoint contains the CLIP-L position-embedding key, all entries
    starting with ``"text_encoders.clip_l."`` are popped from ``state_dict``
    (mutating the caller's dict) and stored with that prefix removed.

    Args:
        state_dict: Checkpoint state dict; used only when ``clip_l_path`` is falsy.
        clip_l_path: Optional path of a standalone CLIP-L safetensors file.
        attn_mode: Forwarded to ``set_attn_mode`` on the created model.
        clip_dtype: Forwarded to ``sd3_models.create_clip_l``.
        device: Forwarded to ``load_safetensors`` and ``sd3_models.create_clip_l``.
        disable_mmap: Forwarded to ``load_safetensors``.

    Returns:
        The created CLIP-L model, or ``None`` when no weights were found.
    """
    marker_key = "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight"
    weights = None

    if clip_l_path:
        logger.info(f"Loading clip_l from {clip_l_path}...")
        weights = load_safetensors(clip_l_path, device, disable_mmap)
        # Re-key each entry in place under the "transformer." prefix.
        for name in list(weights.keys()):
            weights["transformer." + name] = weights.pop(name)
    elif marker_key in state_dict:
        # found clip_l: remove prefix "text_encoders.clip_l."
        logger.info("clip_l is included in the checkpoint")
        weights = {}
        prefix = "text_encoders.clip_l."
        for name in list(state_dict.keys()):
            if name.startswith(prefix):
                weights[name[len(prefix) :]] = state_dict.pop(name)

    if weights is None:
        return None

    logger.info("Building ClipL")
    encoder = sd3_models.create_clip_l(device, clip_dtype, weights)
    logger.info("Loading state dict...")
    load_info = encoder.load_state_dict(weights)
    logger.info(f"Loaded ClipL: {load_info}")
    encoder.set_attn_mode(attn_mode)
    return encoder
|
||||
|
||||
|
||||
def load_clip_g(
    state_dict: Dict,
    clip_g_path: Optional[str],
    attn_mode: str,
    clip_dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
):
    """Load the CLIP-G text encoder from a separate file or from the checkpoint.

    If ``clip_g_path`` is truthy, the weights are read with ``load_safetensors``
    and every key is re-registered under a ``"transformer."`` prefix. Otherwise,
    if the checkpoint contains the CLIP-G position-embedding key, all entries
    starting with ``"text_encoders.clip_g."`` are popped from ``state_dict``
    (mutating the caller's dict) and stored with that prefix removed.

    Args:
        state_dict: Checkpoint state dict; used only when ``clip_g_path`` is falsy.
        clip_g_path: Optional path of a standalone CLIP-G safetensors file.
        attn_mode: Forwarded to ``set_attn_mode`` on the created model.
        clip_dtype: Forwarded to ``sd3_models.create_clip_g``.
        device: Forwarded to ``load_safetensors`` and ``sd3_models.create_clip_g``.
        disable_mmap: Forwarded to ``load_safetensors``.

    Returns:
        The created CLIP-G model, or ``None`` when no weights were found.
    """
    checkpoint_marker = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
    sd = None

    if clip_g_path:
        logger.info(f"Loading clip_g from {clip_g_path}...")
        sd = load_safetensors(clip_g_path, device, disable_mmap)
        # Re-key each entry in place under the "transformer." prefix.
        for original_key in list(sd.keys()):
            sd["transformer." + original_key] = sd.pop(original_key)
    elif checkpoint_marker in state_dict:
        # found clip_g: remove prefix "text_encoders.clip_g."
        logger.info("clip_g is included in the checkpoint")
        sd = {}
        prefix = "text_encoders.clip_g."
        for original_key in list(state_dict.keys()):
            if not original_key.startswith(prefix):
                continue
            sd[original_key[len(prefix) :]] = state_dict.pop(original_key)

    if sd is None:
        return None

    logger.info("Building ClipG")
    model = sd3_models.create_clip_g(device, clip_dtype, sd)
    logger.info("Loading state dict...")
    result = model.load_state_dict(sd)
    logger.info(f"Loaded ClipG: {result}")
    model.set_attn_mode(attn_mode)
    return model
|
||||
|
||||
|
||||
def load_t5xxl(
    state_dict: Dict,
    t5xxl_path: Optional[str],
    attn_mode: str,
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
):
    """Load the T5-XXL text encoder from a separate file or from the checkpoint.

    If ``t5xxl_path`` is truthy, the weights are read with ``load_safetensors``
    and every key is re-registered under a ``"transformer."`` prefix. Otherwise,
    if the checkpoint contains the T5-XXL first-block attention key, all entries
    starting with ``"text_encoders.t5xxl."`` are popped from ``state_dict``
    (mutating the caller's dict) and stored with that prefix removed.

    Args:
        state_dict: Checkpoint state dict; used only when ``t5xxl_path`` is falsy.
        t5xxl_path: Optional path of a standalone T5-XXL safetensors file.
        attn_mode: Forwarded to ``set_attn_mode`` on the created model.
        dtype: Target dtype applied with ``.to(dtype=...)`` after creation.
        device: Forwarded to ``load_safetensors`` and ``sd3_models.create_t5xxl``.
        disable_mmap: Forwarded to ``load_safetensors``.

    Returns:
        The created T5-XXL model, or ``None`` when no weights were found.
    """
    marker_key = "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight"
    weights = None

    if t5xxl_path:
        logger.info(f"Loading t5xxl from {t5xxl_path}...")
        weights = load_safetensors(t5xxl_path, device, disable_mmap)
        # Re-key each entry in place under the "transformer." prefix.
        for name in list(weights.keys()):
            weights["transformer." + name] = weights.pop(name)
    elif marker_key in state_dict:
        # found t5xxl: remove prefix "text_encoders.t5xxl."
        logger.info("t5xxl is included in the checkpoint")
        weights = {}
        prefix = "text_encoders.t5xxl."
        for name in list(state_dict.keys()):
            if name.startswith(prefix):
                weights[name[len(prefix) :]] = state_dict.pop(name)

    if weights is None:
        return None

    logger.info("Building T5XXL")

    # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
    # The model is therefore created as float32 first and cast afterwards.
    encoder = sd3_models.create_t5xxl(device, torch.float32, weights)
    encoder.to(dtype=dtype)

    logger.info("Loading state dict...")
    load_info = encoder.load_state_dict(weights)
    logger.info(f"Loaded T5XXL: {load_info}")
    encoder.set_attn_mode(attn_mode)
    return encoder
|
||||
|
||||
|
||||
def load_vae(
    state_dict: Dict,
    vae_path: Optional[str],
    vae_dtype: Optional[Union[str, torch.dtype]],
    device: Optional[Union[str, torch.device]],
    disable_mmap: bool = False,
):
    """Build the VAE and load its weights from a separate file or the checkpoint.

    If ``vae_path`` is truthy, the weights are read with ``load_safetensors``.
    Otherwise every entry of ``state_dict`` whose key starts with
    ``"first_stage_model."`` is popped (mutating the caller's dict) and stored
    with that prefix removed.

    Args:
        state_dict: Checkpoint state dict; used only when ``vae_path`` is falsy.
        vae_path: Optional path of a standalone VAE safetensors file.
        vae_dtype: dtype the VAE is moved to after loading.
        device: Device the VAE is moved to after loading; also forwarded to
            ``load_safetensors``.
        disable_mmap: Forwarded to ``load_safetensors``.

    Returns:
        The ``sd3_models.SDVAE`` instance after ``load_state_dict`` and ``.to``.
    """
    if vae_path:
        logger.info(f"Loading VAE from {vae_path}...")
        weights = load_safetensors(vae_path, device, disable_mmap)
    else:
        # remove prefix "first_stage_model."
        prefix = "first_stage_model."
        weights = {
            name[len(prefix) :]: state_dict.pop(name)
            for name in list(state_dict.keys())
            if name.startswith(prefix)
        }

    logger.info("Building VAE")
    model = sd3_models.SDVAE()
    logger.info("Loading state dict...")
    load_info = model.load_state_dict(weights)
    logger.info(f"Loaded VAE: {load_info}")
    # Move to the target device/dtype only after the weights have been loaded.
    model.to(device=device, dtype=vae_dtype)
    return model
|
||||
|
||||
|
||||
def load_models(
|
||||
ckpt_path: str,
|
||||
clip_l_path: str,
|
||||
|
||||
Reference in New Issue
Block a user