mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactor SD3 CLIP to transformers etc.
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from typing import Dict, Optional, Union
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
import safetensors
|
||||
from safetensors.torch import load_file
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
@@ -19,18 +22,61 @@ from library import sdxl_model_util
|
||||
|
||||
# region models
|
||||
|
||||
# TODO remove dependency on flux_utils
|
||||
from library.utils import load_safetensors
|
||||
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
|
||||
|
||||
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(path, "rb").read())
|
||||
|
||||
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
|
||||
logger.info(f"Analyzing state dict state...")
|
||||
|
||||
# analyze configs
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
|
||||
|
||||
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
|
||||
x_block_self_attn_layers = []
|
||||
re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
|
||||
for key in list(state_dict.keys()):
|
||||
m = re_attn.match(key)
|
||||
if m:
|
||||
x_block_self_attn_layers.append(int(m.group(1)))
|
||||
|
||||
assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported"
|
||||
|
||||
context_embedder_in_features = context_shape[1]
|
||||
context_embedder_out_features = context_shape[0]
|
||||
|
||||
# only supports 3-5-large and 3-medium
|
||||
if qk_norm is not None:
|
||||
model_type = "3-5-large"
|
||||
else:
|
||||
try:
|
||||
return load_file(path, device=dvc)
|
||||
except:
|
||||
return load_file(path) # prevent device invalid Error
|
||||
model_type = "3-medium"
|
||||
|
||||
params = sd3_models.SD3Params(
|
||||
patch_size=patch_size,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
adm_in_channels=adm_in_channels,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
context_embedder_in_features=context_embedder_in_features,
|
||||
context_embedder_out_features=context_embedder_out_features,
|
||||
model_type=model_type,
|
||||
)
|
||||
logger.info(f"Analyzed state dict state: {params}")
|
||||
return params
|
||||
|
||||
|
||||
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
|
||||
def load_mmdit(
|
||||
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
|
||||
) -> sd3_models.MMDiT:
|
||||
mmdit_sd = {}
|
||||
|
||||
mmdit_prefix = "model.diffusion_model."
|
||||
@@ -40,8 +86,9 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
|
||||
|
||||
# load MMDiT
|
||||
logger.info("Building MMDit")
|
||||
params = analyze_state_dict_state(mmdit_sd)
|
||||
with init_empty_weights():
|
||||
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
|
||||
@@ -50,20 +97,14 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
|
||||
|
||||
|
||||
def load_clip_l(
|
||||
state_dict: Dict,
|
||||
clip_l_path: Optional[str],
|
||||
attn_mode: str,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]],
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
clip_l_sd = None
|
||||
if clip_l_path:
|
||||
logger.info(f"Loading clip_l from {clip_l_path}...")
|
||||
clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
|
||||
for key in list(clip_l_sd.keys()):
|
||||
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
|
||||
else:
|
||||
if clip_l_path is None:
|
||||
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_l: remove prefix "text_encoders.clip_l."
|
||||
logger.info("clip_l is included in the checkpoint")
|
||||
@@ -72,34 +113,58 @@ def load_clip_l(
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
elif clip_l_path is None:
|
||||
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
|
||||
return None
|
||||
|
||||
# load clip_l
|
||||
logger.info("Building CLIP-L")
|
||||
config = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=768,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
clip = CLIPTextModelWithProjection(config)
|
||||
|
||||
if clip_l_sd is None:
|
||||
clip_l = None
|
||||
else:
|
||||
logger.info("Building ClipL")
|
||||
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_l.load_state_dict(clip_l_sd)
|
||||
logger.info(f"Loaded ClipL: {info}")
|
||||
clip_l.set_attn_mode(attn_mode)
|
||||
return clip_l
|
||||
logger.info(f"Loading state dict from {clip_l_path}")
|
||||
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
if "text_projection.weight" not in clip_l_sd:
|
||||
logger.info("Adding text_projection.weight to clip_l_sd")
|
||||
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
|
||||
|
||||
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded CLIP-L: {info}")
|
||||
return clip
|
||||
|
||||
|
||||
def load_clip_g(
|
||||
state_dict: Dict,
|
||||
clip_g_path: Optional[str],
|
||||
attn_mode: str,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]],
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
clip_g_sd = None
|
||||
if clip_g_path:
|
||||
logger.info(f"Loading clip_g from {clip_g_path}...")
|
||||
clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
|
||||
for key in list(clip_g_sd.keys()):
|
||||
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
|
||||
else:
|
||||
if state_dict is not None:
|
||||
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_g: remove prefix "text_encoders.clip_g."
|
||||
logger.info("clip_g is included in the checkpoint")
|
||||
@@ -108,34 +173,53 @@ def load_clip_g(
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
elif clip_g_path is None:
|
||||
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
|
||||
return None
|
||||
|
||||
# load clip_g
|
||||
logger.info("Building CLIP-G")
|
||||
config = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=1280,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
clip = CLIPTextModelWithProjection(config)
|
||||
|
||||
if clip_g_sd is None:
|
||||
clip_g = None
|
||||
else:
|
||||
logger.info("Building ClipG")
|
||||
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_g.load_state_dict(clip_g_sd)
|
||||
logger.info(f"Loaded ClipG: {info}")
|
||||
clip_g.set_attn_mode(attn_mode)
|
||||
return clip_g
|
||||
logger.info(f"Loading state dict from {clip_g_path}")
|
||||
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded CLIP-G: {info}")
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5xxl(
|
||||
state_dict: Dict,
|
||||
t5xxl_path: Optional[str],
|
||||
attn_mode: str,
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
t5xxl_sd = None
|
||||
if t5xxl_path:
|
||||
logger.info(f"Loading t5xxl from {t5xxl_path}...")
|
||||
t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
|
||||
for key in list(t5xxl_sd.keys()):
|
||||
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
|
||||
else:
|
||||
if state_dict is not None:
|
||||
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
||||
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
||||
logger.info("t5xxl is included in the checkpoint")
|
||||
@@ -144,29 +228,19 @@ def load_t5xxl(
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
elif t5xxl_path is None:
|
||||
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
|
||||
return None
|
||||
|
||||
if t5xxl_sd is None:
|
||||
t5xxl = None
|
||||
else:
|
||||
logger.info("Building T5XXL")
|
||||
|
||||
# workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
|
||||
t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
|
||||
t5xxl.to(dtype=dtype)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = t5xxl.load_state_dict(t5xxl_sd)
|
||||
logger.info(f"Loaded T5XXL: {info}")
|
||||
t5xxl.set_attn_mode(attn_mode)
|
||||
return t5xxl
|
||||
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
|
||||
|
||||
|
||||
def load_vae(
|
||||
state_dict: Dict,
|
||||
vae_path: Optional[str],
|
||||
vae_dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Optional[Union[str, torch.device]],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
vae_sd = {}
|
||||
if vae_path:
|
||||
@@ -181,299 +255,15 @@ def load_vae(
|
||||
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
||||
|
||||
logger.info("Building VAE")
|
||||
vae = sd3_models.SDVAE()
|
||||
vae = sd3_models.SDVAE(vae_dtype, device)
|
||||
logger.info("Loading state dict...")
|
||||
info = vae.load_state_dict(vae_sd)
|
||||
logger.info(f"Loaded VAE: {info}")
|
||||
vae.to(device=device, dtype=vae_dtype)
|
||||
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
|
||||
return vae
|
||||
|
||||
|
||||
def load_models(
|
||||
ckpt_path: str,
|
||||
clip_l_path: str,
|
||||
clip_g_path: str,
|
||||
t5xxl_path: str,
|
||||
vae_path: str,
|
||||
attn_mode: str,
|
||||
device: Union[str, torch.device],
|
||||
weight_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
disable_mmap: bool = False,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
t5xxl_device: Optional[Union[str, torch.device]] = None,
|
||||
t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
vae_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
):
|
||||
"""
|
||||
Load SD3 models from checkpoint files.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the SD3 checkpoint file.
|
||||
clip_l_path: Path to the clip_l checkpoint file.
|
||||
clip_g_path: Path to the clip_g checkpoint file.
|
||||
t5xxl_path: Path to the t5xxl checkpoint file.
|
||||
vae_path: Path to the VAE checkpoint file.
|
||||
attn_mode: Attention mode for MMDiT model.
|
||||
device: Device for MMDiT model.
|
||||
weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different.
|
||||
disable_mmap: Disable memory mapping when loading state dict.
|
||||
clip_dtype: Dtype for Clip models, or None to use default dtype.
|
||||
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
|
||||
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
|
||||
vae_dtype: Dtype for VAE model, or None to use default dtype.
|
||||
|
||||
Returns:
|
||||
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
|
||||
"""
|
||||
|
||||
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
|
||||
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
|
||||
# Therefore, we need clip_dtype and t5xxl_dtype.
|
||||
|
||||
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(path, "rb").read())
|
||||
else:
|
||||
try:
|
||||
return load_file(path, device=dvc)
|
||||
except:
|
||||
return load_file(path) # prevent device invalid Error
|
||||
|
||||
t5xxl_device = t5xxl_device or device
|
||||
clip_dtype = clip_dtype or weight_dtype or torch.float32
|
||||
t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32
|
||||
vae_dtype = vae_dtype or weight_dtype or torch.float32
|
||||
|
||||
logger.info(f"Loading SD3 models from {ckpt_path}...")
|
||||
state_dict = load_state_dict(ckpt_path)
|
||||
|
||||
# load clip_l
|
||||
clip_l_sd = None
|
||||
if clip_l_path:
|
||||
logger.info(f"Loading clip_l from {clip_l_path}...")
|
||||
clip_l_sd = load_state_dict(clip_l_path)
|
||||
for key in list(clip_l_sd.keys()):
|
||||
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
|
||||
else:
|
||||
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_l: remove prefix "text_encoders.clip_l."
|
||||
logger.info("clip_l is included in the checkpoint")
|
||||
clip_l_sd = {}
|
||||
prefix = "text_encoders.clip_l."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
|
||||
# load clip_g
|
||||
clip_g_sd = None
|
||||
if clip_g_path:
|
||||
logger.info(f"Loading clip_g from {clip_g_path}...")
|
||||
clip_g_sd = load_state_dict(clip_g_path)
|
||||
for key in list(clip_g_sd.keys()):
|
||||
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
|
||||
else:
|
||||
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_g: remove prefix "text_encoders.clip_g."
|
||||
logger.info("clip_g is included in the checkpoint")
|
||||
clip_g_sd = {}
|
||||
prefix = "text_encoders.clip_g."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
|
||||
# load t5xxl
|
||||
t5xxl_sd = None
|
||||
if t5xxl_path:
|
||||
logger.info(f"Loading t5xxl from {t5xxl_path}...")
|
||||
t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device)
|
||||
for key in list(t5xxl_sd.keys()):
|
||||
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
|
||||
else:
|
||||
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
||||
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
||||
logger.info("t5xxl is included in the checkpoint")
|
||||
t5xxl_sd = {}
|
||||
prefix = "text_encoders.t5xxl."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
|
||||
# MMDiT and VAE
|
||||
vae_sd = {}
|
||||
if vae_path:
|
||||
logger.info(f"Loading VAE from {vae_path}...")
|
||||
vae_sd = load_state_dict(vae_path)
|
||||
else:
|
||||
# remove prefix "first_stage_model."
|
||||
vae_sd = {}
|
||||
vae_prefix = "first_stage_model."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(vae_prefix):
|
||||
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
||||
|
||||
mmdit_prefix = "model.diffusion_model."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(mmdit_prefix):
|
||||
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
|
||||
else:
|
||||
state_dict.pop(k) # remove other keys
|
||||
|
||||
# load MMDiT
|
||||
logger.info("Building MMDit")
|
||||
with init_empty_weights():
|
||||
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
|
||||
logger.info(f"Loaded MMDiT: {info}")
|
||||
|
||||
# load ClipG and ClipL
|
||||
if clip_l_sd is None:
|
||||
clip_l = None
|
||||
else:
|
||||
logger.info("Building ClipL")
|
||||
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_l.load_state_dict(clip_l_sd)
|
||||
logger.info(f"Loaded ClipL: {info}")
|
||||
clip_l.set_attn_mode(attn_mode)
|
||||
|
||||
if clip_g_sd is None:
|
||||
clip_g = None
|
||||
else:
|
||||
logger.info("Building ClipG")
|
||||
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_g.load_state_dict(clip_g_sd)
|
||||
logger.info(f"Loaded ClipG: {info}")
|
||||
clip_g.set_attn_mode(attn_mode)
|
||||
|
||||
# load T5XXL
|
||||
if t5xxl_sd is None:
|
||||
t5xxl = None
|
||||
else:
|
||||
logger.info("Building T5XXL")
|
||||
t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = t5xxl.load_state_dict(t5xxl_sd)
|
||||
logger.info(f"Loaded T5XXL: {info}")
|
||||
t5xxl.set_attn_mode(attn_mode)
|
||||
|
||||
# load VAE
|
||||
logger.info("Building VAE")
|
||||
vae = sd3_models.SDVAE()
|
||||
logger.info("Loading state dict...")
|
||||
info = vae.load_state_dict(vae_sd)
|
||||
logger.info(f"Loaded VAE: {info}")
|
||||
vae.to(device=device, dtype=vae_dtype)
|
||||
|
||||
return mmdit, clip_l, clip_g, t5xxl, vae
|
||||
|
||||
|
||||
# endregion
|
||||
# region utils
|
||||
|
||||
|
||||
def get_cond(
|
||||
prompt: str,
|
||||
tokenizer: sd3_models.SD3Tokenizer,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
|
||||
print(t5_tokens)
|
||||
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def get_cond_from_tokens(
|
||||
l_tokens,
|
||||
g_tokens,
|
||||
t5_tokens,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
l_out, l_pooled = clip_l.encode_token_weights(l_tokens)
|
||||
g_out, g_pooled = clip_g.encode_token_weights(g_tokens)
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
if device is not None:
|
||||
lg_out = lg_out.to(device=device)
|
||||
l_pooled = l_pooled.to(device=device)
|
||||
g_pooled = g_pooled.to(device=device)
|
||||
if dtype is not None:
|
||||
lg_out = lg_out.to(dtype=dtype)
|
||||
l_pooled = l_pooled.to(dtype=dtype)
|
||||
g_pooled = g_pooled.to(dtype=dtype)
|
||||
|
||||
# t5xxl may be in another device (eg. cpu)
|
||||
if t5_tokens is None:
|
||||
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
|
||||
else:
|
||||
t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
|
||||
if device is not None:
|
||||
t5_out = t5_out.to(device=device)
|
||||
if dtype is not None:
|
||||
t5_out = t5_out.to(dtype=dtype)
|
||||
|
||||
# return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
|
||||
# used if other sd3 models is available
|
||||
r"""
|
||||
def get_sd3_configs(state_dict: Dict):
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
# prefix = "model.diffusion_model."
|
||||
prefix = ""
|
||||
|
||||
patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[prefix + "pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[prefix + "context_embedder.weight"].shape
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {"in_features": context_shape[1], "out_features": context_shape[0]},
|
||||
}
|
||||
return {
|
||||
"patch_size": patch_size,
|
||||
"depth": depth,
|
||||
"num_patches": num_patches,
|
||||
"pos_embed_max_size": pos_embed_max_size,
|
||||
"adm_in_channels": adm_in_channels,
|
||||
"context_embedder": context_embedder_config,
|
||||
}
|
||||
|
||||
|
||||
def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"):
|
||||
""
|
||||
Doesn't load state dict.
|
||||
""
|
||||
sd3_configs = get_sd3_configs(state_dict)
|
||||
|
||||
mmdit = sd3_models.MMDiT(
|
||||
input_size=None,
|
||||
pos_embed_max_size=sd3_configs["pos_embed_max_size"],
|
||||
patch_size=sd3_configs["patch_size"],
|
||||
in_channels=16,
|
||||
adm_in_channels=sd3_configs["adm_in_channels"],
|
||||
depth=sd3_configs["depth"],
|
||||
mlp_ratio=4,
|
||||
qk_norm=None,
|
||||
num_patches=sd3_configs["num_patches"],
|
||||
context_size=4096,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
return mmdit
|
||||
"""
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow:
|
||||
@@ -509,6 +299,3 @@ class ModelSamplingDiscreteFlow:
|
||||
# assert max_denoise is False, "max_denoise not implemented"
|
||||
# max_denoise is always True, I'm not sure why it's there
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user