mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
use CLIPTextModelWithProjection
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import CLIPTextModel, CLIPTextConfig
|
||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection
|
||||
from diffusers import AutoencoderKL
|
||||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
@@ -13,7 +13,7 @@ MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
|
||||
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
||||
|
||||
# SD2のと、基本的には同じ。text_projectionを後で使うので、それを追加で返す
|
||||
# SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
|
||||
# logit_scaleはcheckpointの保存時に使用する
|
||||
def convert_key(key):
|
||||
# common conversion
|
||||
@@ -37,7 +37,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
elif ".positional_embedding" in key:
|
||||
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
||||
elif ".text_projection" in key:
|
||||
key = None # 後で処理する
|
||||
key = key.replace("text_model.text_projection", "text_projection.weight")
|
||||
elif ".logit_scale" in key:
|
||||
key = None # 後で処理する
|
||||
elif ".token_embedding" in key:
|
||||
@@ -73,11 +73,10 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
|
||||
# text projection, logit_scale はDiffusersには含まれないが、後で必要になるので返す
|
||||
text_projection = checkpoint[SDXL_KEY_PREFIX + "text_projection"]
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
|
||||
|
||||
return new_sd, text_projection, logit_scale
|
||||
return new_sd, logit_scale
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
@@ -164,7 +163,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
text_model2 = CLIPTextModel._from_config(text_model2_cfg)
|
||||
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
||||
|
||||
print("loading text encoders from checkpoint")
|
||||
te1_sd = {}
|
||||
@@ -178,7 +177,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
info1 = text_model1.load_state_dict(te1_sd)
|
||||
print("text encoder 1:", info1)
|
||||
|
||||
converted_sd, text_projection, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||
info2 = text_model2.load_state_dict(converted_sd)
|
||||
print("text encoder2:", info2)
|
||||
|
||||
@@ -193,10 +192,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
print("VAE:", info)
|
||||
|
||||
ckpt_info = (epoch, global_step) if epoch is not None else None
|
||||
return text_model1, text_model2, vae, unet, text_projection, logit_scale, ckpt_info
|
||||
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit_scale):
|
||||
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
||||
def convert_key(key):
|
||||
# position_idsの除去
|
||||
if ".position_ids" in key:
|
||||
@@ -223,6 +222,8 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit
|
||||
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
||||
elif ".token_embedding" in key:
|
||||
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
||||
elif "text_projection" in key: # no dot in key
|
||||
key = key.replace("text_projection.weight", "text_projection")
|
||||
elif "final_layer_norm" in key:
|
||||
key = key.replace("final_layer_norm", "ln_final")
|
||||
return key
|
||||
@@ -252,7 +253,6 @@ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, text_projection, logit
|
||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||
new_sd[new_key] = value
|
||||
|
||||
new_sd["text_projection"] = text_projection
|
||||
new_sd["logit_scale"] = logit_scale
|
||||
|
||||
return new_sd
|
||||
@@ -267,7 +267,6 @@ def save_stable_diffusion_checkpoint(
|
||||
steps,
|
||||
ckpt_info,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
save_dtype=None,
|
||||
):
|
||||
@@ -286,7 +285,7 @@ def save_stable_diffusion_checkpoint(
|
||||
# Convert the text encoders
|
||||
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
|
||||
|
||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), text_projection, logit_scale)
|
||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
|
||||
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
||||
|
||||
# Convert the VAE
|
||||
|
||||
@@ -28,7 +28,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
@@ -46,7 +45,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
||||
@@ -64,7 +63,6 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
||||
@@ -74,7 +72,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
class WrapperTokenizer:
|
||||
@@ -138,7 +136,7 @@ def get_hidden_states(
|
||||
# text_encoder2
|
||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
pool2 = enc_out["pooler_output"]
|
||||
pool2 = enc_out["text_embeds"]
|
||||
|
||||
if args.max_token_length is not None:
|
||||
# bs*3, 77, 768 or 1024
|
||||
@@ -218,7 +216,6 @@ def save_sd_model_on_train_end(
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
):
|
||||
@@ -232,7 +229,6 @@ def save_sd_model_on_train_end(
|
||||
global_step,
|
||||
ckpt_info,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
save_dtype,
|
||||
)
|
||||
@@ -262,7 +258,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
):
|
||||
@@ -276,7 +271,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
global_step,
|
||||
ckpt_info,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user