import math from typing import Dict, Optional, Union import torch import safetensors from safetensors.torch import load_file from accelerate import init_empty_weights from .utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) from library import sd3_models # TODO move some of functions to model_util.py from library import sdxl_model_util # region models def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) else: try: return load_file(path, device=dvc) except: return load_file(path) # prevent device invalid Error def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): mmdit_sd = {} mmdit_prefix = "model.diffusion_model." for k in list(state_dict.keys()): if k.startswith(mmdit_prefix): mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) # load MMDiT logger.info("Building MMDit") with init_empty_weights(): mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) logger.info(f"Loaded MMDiT: {info}") return mmdit def load_clip_l( state_dict: Dict, clip_l_path: Optional[str], attn_mode: str, clip_dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, ): clip_l_sd = None if clip_l_path: logger.info(f"Loading clip_l from {clip_l_path}...") clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) for key in list(clip_l_sd.keys()): clip_l_sd["transformer." + key] = clip_l_sd.pop(key) else: if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_l: remove prefix "text_encoders.clip_l." logger.info("clip_l is included in the checkpoint") clip_l_sd = {} prefix = "text_encoders.clip_l." for k in list(state_dict.keys()): if k.startswith(prefix): clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) if clip_l_sd is None: clip_l = None else: logger.info("Building ClipL") clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) logger.info("Loading state dict...") info = clip_l.load_state_dict(clip_l_sd) logger.info(f"Loaded ClipL: {info}") clip_l.set_attn_mode(attn_mode) return clip_l def load_clip_g( state_dict: Dict, clip_g_path: Optional[str], attn_mode: str, clip_dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, ): clip_g_sd = None if clip_g_path: logger.info(f"Loading clip_g from {clip_g_path}...") clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) for key in list(clip_g_sd.keys()): clip_g_sd["transformer." + key] = clip_g_sd.pop(key) else: if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_g: remove prefix "text_encoders.clip_g." logger.info("clip_g is included in the checkpoint") clip_g_sd = {} prefix = "text_encoders.clip_g." for k in list(state_dict.keys()): if k.startswith(prefix): clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) if clip_g_sd is None: clip_g = None else: logger.info("Building ClipG") clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) logger.info("Loading state dict...") info = clip_g.load_state_dict(clip_g_sd) logger.info(f"Loaded ClipG: {info}") clip_g.set_attn_mode(attn_mode) return clip_g def load_t5xxl( state_dict: Dict, t5xxl_path: Optional[str], attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, ): t5xxl_sd = None if t5xxl_path: logger.info(f"Loading t5xxl from {t5xxl_path}...") t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) for key in list(t5xxl_sd.keys()): t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) else: if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: # found t5xxl: remove prefix "text_encoders.t5xxl." logger.info("t5xxl is included in the checkpoint") t5xxl_sd = {} prefix = "text_encoders.t5xxl." for k in list(state_dict.keys()): if k.startswith(prefix): t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) if t5xxl_sd is None: t5xxl = None else: logger.info("Building T5XXL") # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) t5xxl.to(dtype=dtype) logger.info("Loading state dict...") info = t5xxl.load_state_dict(t5xxl_sd) logger.info(f"Loaded T5XXL: {info}") t5xxl.set_attn_mode(attn_mode) return t5xxl def load_vae( state_dict: Dict, vae_path: Optional[str], vae_dtype: Optional[Union[str, torch.dtype]], device: Optional[Union[str, torch.device]], disable_mmap: bool = False, ): vae_sd = {} if vae_path: logger.info(f"Loading VAE from {vae_path}...") vae_sd = load_safetensors(vae_path, device, disable_mmap) else: # remove prefix "first_stage_model." vae_sd = {} vae_prefix = "first_stage_model." for k in list(state_dict.keys()): if k.startswith(vae_prefix): vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) logger.info("Building VAE") vae = sd3_models.SDVAE() logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") vae.to(device=device, dtype=vae_dtype) return vae def load_models( ckpt_path: str, clip_l_path: str, clip_g_path: str, t5xxl_path: str, vae_path: str, attn_mode: str, device: Union[str, torch.device], weight_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, clip_dtype: Optional[Union[str, torch.dtype]] = None, t5xxl_device: Optional[Union[str, torch.device]] = None, t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, vae_dtype: Optional[Union[str, torch.dtype]] = None, ): """ Load SD3 models from checkpoint files. Args: ckpt_path: Path to the SD3 checkpoint file. clip_l_path: Path to the clip_l checkpoint file. clip_g_path: Path to the clip_g checkpoint file. t5xxl_path: Path to the t5xxl checkpoint file. vae_path: Path to the VAE checkpoint file. attn_mode: Attention mode for MMDiT model. device: Device for MMDiT model. weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. disable_mmap: Disable memory mapping when loading state dict. clip_dtype: Dtype for Clip models, or None to use default dtype. t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. vae_dtype: Dtype for VAE model, or None to use default dtype. Returns: Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. """ # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. # Therefore, we need clip_dtype and t5xxl_dtype. def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) else: try: return load_file(path, device=dvc) except: return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device clip_dtype = clip_dtype or weight_dtype or torch.float32 t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 vae_dtype = vae_dtype or weight_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) # load clip_l clip_l_sd = None if clip_l_path: logger.info(f"Loading clip_l from {clip_l_path}...") clip_l_sd = load_state_dict(clip_l_path) for key in list(clip_l_sd.keys()): clip_l_sd["transformer." + key] = clip_l_sd.pop(key) else: if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_l: remove prefix "text_encoders.clip_l." logger.info("clip_l is included in the checkpoint") clip_l_sd = {} prefix = "text_encoders.clip_l." for k in list(state_dict.keys()): if k.startswith(prefix): clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) # load clip_g clip_g_sd = None if clip_g_path: logger.info(f"Loading clip_g from {clip_g_path}...") clip_g_sd = load_state_dict(clip_g_path) for key in list(clip_g_sd.keys()): clip_g_sd["transformer." + key] = clip_g_sd.pop(key) else: if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_g: remove prefix "text_encoders.clip_g." logger.info("clip_g is included in the checkpoint") clip_g_sd = {} prefix = "text_encoders.clip_g." for k in list(state_dict.keys()): if k.startswith(prefix): clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) # load t5xxl t5xxl_sd = None if t5xxl_path: logger.info(f"Loading t5xxl from {t5xxl_path}...") t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) for key in list(t5xxl_sd.keys()): t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) else: if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: # found t5xxl: remove prefix "text_encoders.t5xxl." logger.info("t5xxl is included in the checkpoint") t5xxl_sd = {} prefix = "text_encoders.t5xxl." for k in list(state_dict.keys()): if k.startswith(prefix): t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) # MMDiT and VAE vae_sd = {} if vae_path: logger.info(f"Loading VAE from {vae_path}...") vae_sd = load_state_dict(vae_path) else: # remove prefix "first_stage_model." vae_sd = {} vae_prefix = "first_stage_model." for k in list(state_dict.keys()): if k.startswith(vae_prefix): vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) mmdit_prefix = "model.diffusion_model." for k in list(state_dict.keys()): if k.startswith(mmdit_prefix): state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) else: state_dict.pop(k) # remove other keys # load MMDiT logger.info("Building MMDit") with init_empty_weights(): mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL if clip_l_sd is None: clip_l = None else: logger.info("Building ClipL") clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) logger.info("Loading state dict...") info = clip_l.load_state_dict(clip_l_sd) logger.info(f"Loaded ClipL: {info}") clip_l.set_attn_mode(attn_mode) if clip_g_sd is None: clip_g = None else: logger.info("Building ClipG") clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) logger.info("Loading state dict...") info = clip_g.load_state_dict(clip_g_sd) logger.info(f"Loaded ClipG: {info}") clip_g.set_attn_mode(attn_mode) # load T5XXL if t5xxl_sd is None: t5xxl = None else: logger.info("Building T5XXL") t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) logger.info("Loading state dict...") info = t5xxl.load_state_dict(t5xxl_sd) logger.info(f"Loaded T5XXL: {info}") t5xxl.set_attn_mode(attn_mode) # load VAE logger.info("Building VAE") vae = sd3_models.SDVAE() logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") vae.to(device=device, dtype=vae_dtype) return mmdit, clip_l, clip_g, t5xxl, vae # endregion # region utils def get_cond( prompt: str, tokenizer: sd3_models.SD3Tokenizer, clip_l: sd3_models.SDClipModel, clip_g: sd3_models.SDXLClipG, t5xxl: Optional[sd3_models.T5XXLModel] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) def get_cond_from_tokens( l_tokens, g_tokens, t5_tokens, clip_l: sd3_models.SDClipModel, clip_g: sd3_models.SDXLClipG, t5xxl: Optional[sd3_models.T5XXLModel] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): l_out, l_pooled = clip_l.encode_token_weights(l_tokens) g_out, g_pooled = clip_g.encode_token_weights(g_tokens) lg_out = torch.cat([l_out, g_out], dim=-1) lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) if device is not None: lg_out = lg_out.to(device=device) l_pooled = l_pooled.to(device=device) g_pooled = g_pooled.to(device=device) if dtype is not None: lg_out = lg_out.to(dtype=dtype) l_pooled = l_pooled.to(dtype=dtype) g_pooled = g_pooled.to(dtype=dtype) # t5xxl may be in another device (eg. cpu) if t5_tokens is None: t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) else: t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None if device is not None: t5_out = t5_out.to(device=device) if dtype is not None: t5_out = t5_out.to(dtype=dtype) # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) # used if other sd3 models is available r""" def get_sd3_configs(state_dict: Dict): # Important configuration values can be quickly determined by checking shapes in the source file # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) # prefix = "model.diffusion_model." prefix = "" patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 num_patches = state_dict[prefix + "pos_embed"].shape[1] pos_embed_max_size = round(math.sqrt(num_patches)) adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] context_shape = state_dict[prefix + "context_embedder.weight"].shape context_embedder_config = { "target": "torch.nn.Linear", "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, } return { "patch_size": patch_size, "depth": depth, "num_patches": num_patches, "pos_embed_max_size": pos_embed_max_size, "adm_in_channels": adm_in_channels, "context_embedder": context_embedder_config, } def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): "" Doesn't load state dict. "" sd3_configs = get_sd3_configs(state_dict) mmdit = sd3_models.MMDiT( input_size=None, pos_embed_max_size=sd3_configs["pos_embed_max_size"], patch_size=sd3_configs["patch_size"], in_channels=16, adm_in_channels=sd3_configs["adm_in_channels"], depth=sd3_configs["depth"], mlp_ratio=4, qk_norm=None, num_patches=sd3_configs["num_patches"], context_size=4096, attn_mode=attn_mode, ) return mmdit """ class ModelSamplingDiscreteFlow: """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" def __init__(self, shift=1.0): self.shift = shift timesteps = 1000 self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) @property def sigma_min(self): return self.sigmas[0] @property def sigma_max(self): return self.sigmas[-1] def timestep(self, sigma): return sigma * 1000 def sigma(self, timestep: torch.Tensor): timestep = timestep / 1000.0 if self.shift == 1.0: return timestep return self.shift * timestep / (1 + (self.shift - 1) * timestep) def calculate_denoised(self, sigma, model_output, model_input): sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image # endregion