mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
310 lines
11 KiB
Python
310 lines
11 KiB
Python
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from transformers import CLIPTextModel, CLIPTextConfig
|
|
from diffusers import AutoencoderKL
|
|
from library import model_util
|
|
from library import sdxl_original_unet
|
|
|
|
|
|
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
    """Rename SDXL text-encoder-2 weights to the ``text_model.*`` key layout.

    The conversion is basically the same as the one used for SD2.
    ``text_projection`` and ``logit_scale`` are not placed in the converted
    dict; they are returned separately because they are needed later
    (``logit_scale`` is used when the checkpoint is saved again).

    Args:
        checkpoint: mapping of key -> tensor whose keys start with
            ``conditioner.embedders.1.model.``.
        max_length: number of entries in the generated ``position_ids`` row.

    Returns:
        Tuple ``(new_sd, text_projection, logit_scale)``.

    Raises:
        ValueError: a ``resblocks`` key matches none of the known sub-modules.
        KeyError: ``text_projection`` or ``logit_scale`` is missing from
            ``checkpoint``.
    """
    sdxl_prefix = "conditioner.embedders.1.model."

    def rename(src_key):
        # Common prefix conversion (the longer prefix must be handled first).
        name = src_key.replace(sdxl_prefix + "transformer.", "text_model.encoder.")
        name = name.replace(sdxl_prefix, "text_model.")

        if "resblocks" in name:
            # Per-layer (resblocks) conversion.
            name = name.replace(".resblocks.", ".layers.")
            if ".ln_" in name:
                return name.replace(".ln_", ".layer_norm")
            if ".mlp." in name:
                return name.replace(".c_fc.", ".fc1.").replace(".c_proj.", ".fc2.")
            if ".attn.out_proj" in name:
                return name.replace(".attn.out_proj.", ".self_attn.out_proj.")
            if ".attn.in_proj" in name:
                # Fused q/k/v projection: special case, split in the second pass.
                return None
            raise ValueError(f"unexpected key in SD: {name}")

        if ".positional_embedding" in name:
            return name.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        if ".text_projection" in name or ".logit_scale" in name:
            # Handled separately at the end of the function.
            return None
        if ".token_embedding" in name:
            return name.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        if ".ln_final" in name:
            return name.replace(".ln_final", ".final_layer_norm")
        return name

    source_keys = list(checkpoint.keys())

    # First pass: every key that maps one-to-one.
    new_sd = {}
    for src_key in source_keys:
        dst_key = rename(src_key)
        if dst_key is not None:
            new_sd[dst_key] = checkpoint[src_key]

    # Second pass: attention conversion, splitting each fused in_proj into three.
    for src_key in source_keys:
        if ".resblocks" not in src_key or ".attn.in_proj_" not in src_key:
            continue

        thirds = torch.chunk(checkpoint[src_key], 3)
        suffix = ".weight" if "weight" in src_key else ".bias"

        stem = src_key.replace(sdxl_prefix + "transformer.resblocks.", "text_model.encoder.layers.")
        stem = stem.replace("_weight", "").replace("_bias", "")
        stem = stem.replace(".attn.in_proj", ".self_attn.")

        for idx, proj in enumerate(("q_proj", "k_proj", "v_proj")):
            new_sd[stem + proj + suffix] = thirds[idx]

    # position_ids does not exist in the original SD checkpoint, so add it here.
    new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)

    # text_projection and logit_scale are not part of the converted dict, but
    # they are needed later, so hand them back to the caller.
    text_projection = checkpoint[sdxl_prefix + "text_projection"]
    logit_scale = checkpoint[sdxl_prefix + "logit_scale"]

    return new_sd, text_projection, logit_scale
|
|
|
|
|
|
def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
    """Build the SDXL sub-models and fill them from a single checkpoint file.

    The state dict is consumed in stages: U-Net keys
    (``model.diffusion_model.``) are popped first, then the two text encoders
    (``conditioner.embedders.0.transformer.`` and
    ``conditioner.embedders.1.model.``), and whatever remains is handed to the
    VAE converter.

    Args:
        model_type: currently unused; reserved for future use.
        ckpt_path: path of the checkpoint; ``model_util.is_safetensors`` decides
            whether it is read with ``load_file`` or ``torch.load``.
        map_location: passed as ``device`` to ``load_file`` or as
            ``map_location`` to ``torch.load``.

    Returns:
        Tuple ``(text_model1, text_model2, vae, unet, text_projection,
        logit_scale, ckpt_info)``. ``ckpt_info`` is ``(epoch, global_step)``,
        or ``None`` when the file was read as safetensors.
    """
    # Load the state dict.
    if model_util.is_safetensors(ckpt_path):
        state_dict = load_file(ckpt_path, device=map_location)
        epoch = global_step = None
    else:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the torch version's defaults -- only load trusted files.
        loaded = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in loaded:
            state_dict = loaded["state_dict"]
            epoch = loaded.get("epoch", 0)
            global_step = loaded.get("global_step", 0)
        else:
            state_dict, epoch, global_step = loaded, 0, 0
        del loaded

    # U-Net
    print("building U-Net")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    unet_prefix = "model.diffusion_model."
    unet_sd = {
        name.replace(unet_prefix, ""): state_dict.pop(name)
        for name in list(state_dict.keys())
        if name.startswith(unet_prefix)
    }
    info = unet.load_state_dict(unet_sd)
    print("U-Net: ", info)
    del unet_sd

    # Text Encoders
    print("building text encoders")

    # Settings that are identical for both text encoder configs.
    shared_cfg = dict(
        vocab_size=49408,
        max_position_embeddings=77,
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
    )

    # Text Encoder 1 is the same as in SDXL.
    text_model1_cfg = CLIPTextConfig(
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_act="quick_gelu",
        projection_dim=768,
        **shared_cfg,
    )
    text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    # Text Encoder 2 is different from SDXL. SDXL uses open clip, but we use the model from HuggingFace.
    # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
    text_model2_cfg = CLIPTextConfig(
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        hidden_act="gelu",
        projection_dim=1280,
        **shared_cfg,
    )
    text_model2 = CLIPTextModel._from_config(text_model2_cfg)

    print("loading text encoders from checkpoint")
    te1_prefix = "conditioner.embedders.0.transformer."
    te2_prefix = "conditioner.embedders.1.model."
    te1_sd, te2_sd = {}, {}
    for name in list(state_dict.keys()):
        if name.startswith(te1_prefix):
            # Encoder 1 keys are used with the prefix removed.
            te1_sd[name.replace(te1_prefix, "")] = state_dict.pop(name)
        elif name.startswith(te2_prefix):
            # Encoder 2 keys keep their prefix; the converter removes it.
            te2_sd[name] = state_dict.pop(name)

    info1 = text_model1.load_state_dict(te1_sd)
    print("text encoder 1:", info1)

    converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    info2 = text_model2.load_state_dict(converted_sd)
    print("text encoder2:", info2)

    # Prepare the VAE.
    print("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    vae = AutoencoderKL(**vae_config)  # .to(device)

    print("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("VAE:", info)

    if epoch is not None:
        ckpt_info = (epoch, global_step)
    else:
        ckpt_info = None
    return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
|
|
|
|
|
|
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
    """Rename ``text_model.*`` text-encoder-2 weights back to the SDXL key layout.

    ``position_ids`` entries are dropped, the separate q/k/v projections of each
    layer are concatenated into one ``attn.in_proj_*`` tensor, and the supplied
    ``text_projection`` / ``logit_scale`` are stored under those two keys. The
    returned keys carry no ``conditioner.embedders.1.model.`` prefix.

    Args:
        checkpoint: mapping of key -> tensor using ``text_model.*`` keys.
        text_projection: value stored under ``"text_projection"``.
        logit_scale: value stored under ``"logit_scale"``.

    Returns:
        A new dict with the converted keys.

    Raises:
        ValueError: a ``layers`` key matches none of the known sub-modules.
        KeyError: a ``q_proj`` entry has no matching ``k_proj`` / ``v_proj``.
    """

    def rename(src_key):
        # position_ids is removed.
        if ".position_ids" in src_key:
            return None

        # Common prefix conversion.
        name = src_key.replace("text_model.encoder.", "transformer.").replace("text_model.", "")

        if "layers" in name:
            # Per-layer (resblocks) conversion.
            name = name.replace(".layers.", ".resblocks.")
            if ".layer_norm" in name:
                return name.replace(".layer_norm", ".ln_")
            if ".mlp." in name:
                return name.replace(".fc1.", ".c_fc.").replace(".fc2.", ".c_proj.")
            if ".self_attn.out_proj" in name:
                return name.replace(".self_attn.out_proj.", ".attn.out_proj.")
            if ".self_attn." in name:
                # q/k/v projections: special case, merged in the second pass.
                return None
            raise ValueError(f"unexpected key in DiffUsers model: {name}")

        if ".position_embedding" in name:
            return name.replace("embeddings.position_embedding.weight", "positional_embedding")
        if ".token_embedding" in name:
            return name.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        if "final_layer_norm" in name:
            return name.replace("final_layer_norm", "ln_final")
        return name

    source_keys = list(checkpoint.keys())

    # First pass: every key that maps one-to-one.
    new_sd = {}
    for src_key in source_keys:
        dst_key = rename(src_key)
        if dst_key is not None:
            new_sd[dst_key] = checkpoint[src_key]

    # Second pass: attention conversion, joining q, k and v into one tensor.
    for src_key in source_keys:
        if "layers" not in src_key or "q_proj" not in src_key:
            continue

        fused = torch.cat(
            [checkpoint[src_key.replace("q_proj", proj)] for proj in ("q_proj", "k_proj", "v_proj")]
        )

        dst_key = src_key.replace("text_model.encoder.layers.", "transformer.resblocks.")
        dst_key = dst_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
        new_sd[dst_key] = fused

    new_sd["text_projection"] = text_projection
    new_sd["logit_scale"] = logit_scale

    return new_sd
|
|
|
|
|
|
def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    text_projection,
    logit_scale,
    save_dtype=None,
):
    """Merge the SDXL sub-models into one state dict and write it to ``output_file``.

    Key prefixes used in the merged dict:
    ``model.diffusion_model.`` (U-Net),
    ``conditioner.embedders.0.transformer.`` (text encoder 1),
    ``conditioner.embedders.1.model.`` (text encoder 2, after key conversion),
    ``first_stage_model.`` (VAE, after key conversion).

    Args:
        output_file: destination path; ``model_util.is_safetensors`` selects
            between ``save_file`` and ``torch.save``.
        text_encoder1, text_encoder2, unet, vae: objects providing
            ``state_dict()``.
        epochs, steps: counters for this run; the values in ``ckpt_info`` are
            added to them when ``ckpt_info`` is not ``None``.
        ckpt_info: ``(epoch, global_step)`` of the source checkpoint, or ``None``.
        text_projection, logit_scale: stored with the text encoder 2 weights.
        save_dtype: when not ``None``, every tensor is detached, cloned, moved
            to CPU and cast to this dtype before being stored.

    Returns:
        Number of entries in the merged state dict.
    """
    merged = {}

    def add_section(prefix, tensors):
        # Copy one sub-model into the merged dict under the given key prefix.
        for name, tensor in tensors.items():
            if save_dtype is not None:
                tensor = tensor.detach().clone().to("cpu").to(save_dtype)
            merged[prefix + name] = tensor

    # U-Net
    add_section("model.diffusion_model.", unet.state_dict())

    # Text encoders
    add_section("conditioner.embedders.0.transformer.", text_encoder1.state_dict())

    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale)
    add_section("conditioner.embedders.1.model.", text_enc2_dict)

    # VAE
    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
    add_section("first_stage_model.", vae_dict)

    key_count = len(merged)

    # epoch and global_step are sometimes not int
    if ckpt_info is not None:
        epochs += ckpt_info[0]
        steps += ckpt_info[1]

    if model_util.is_safetensors(output_file):
        # Only the tensors are written in this format; the counters are not.
        save_file(merged, output_file)
    else:
        # Put together the new checkpoint, including the counters.
        new_ckpt = {"state_dict": merged, "epoch": epochs, "global_step": steps}
        torch.save(new_ckpt, output_file)

    return key_count
|