Revert unet config, add option to convert script

This commit is contained in:
ykume
2023-05-03 16:05:15 +09:00
parent 1cba447102
commit 758a1e7f66
2 changed files with 23 additions and 17 deletions

View File

@@ -22,7 +22,7 @@ UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
UNET_PARAMS_USE_LINEAR_PROJECTION = False
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
@@ -35,7 +35,7 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
# Diffusersの設定を読み込むための参照モデル
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
@@ -209,13 +209,13 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0]
# def linear_transformer_to_conv(checkpoint):
# keys = list(checkpoint.keys())
# tf_keys = ["proj_in.weight", "proj_out.weight"]
# for key in keys:
# if ".".join(key.split(".")[-2:]) in tf_keys:
# if checkpoint[key].ndim == 2:
# checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
def linear_transformer_to_conv(checkpoint):
    """Append two singleton axes to 2-D transformer projection weights, in place.

    An entry is converted when the last two dot-separated components of its
    key are ``proj_in.weight`` or ``proj_out.weight`` and its value has
    exactly two dimensions: a value of shape ``(a, b)`` becomes
    ``(a, b, 1, 1)``. Entries that already have a different number of
    dimensions, and entries under any other key, are left untouched.

    Args:
        checkpoint: mutable mapping from parameter name to a tensor-like
            value that exposes ``ndim`` and ``unsqueeze``.

    Returns:
        None. ``checkpoint`` is modified in place.
    """
    tf_keys = ("proj_in.weight", "proj_out.weight")
    # Walk a snapshot so that reassigning entries cannot disturb the iteration.
    for key, weight in list(checkpoint.items()):
        # Compare whole trailing components, not a raw string suffix, so a key
        # such as "xproj_in.weight" is not matched by accident.
        suffix = ".".join(key.split(".")[-2:])
        if suffix in tf_keys and weight.ndim == 2:
            checkpoint[key] = weight.unsqueeze(2).unsqueeze(2)
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
@@ -359,9 +359,10 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
new_checkpoint[new_path] = unet_state_dict[old_path]
# SDのv2では1*1のconv2dがlinearに変わっているが、Diffusers側も同じなので、変換不要
# if v2:
# linear_transformer_to_conv(new_checkpoint)
# SDのv2では1*1のconv2dがlinearに変わっている
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
if v2 and not config.get('use_linear_projection', False):
linear_transformer_to_conv(new_checkpoint)
return new_checkpoint
@@ -470,7 +471,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint
def create_unet_diffusers_config(v2):
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
@@ -502,8 +503,10 @@ def create_unet_diffusers_config(v2):
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
)
if v2 and use_linear_projection_in_v2:
config["use_linear_projection"] = True
return config
@@ -849,11 +852,11 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(v2)
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
unet = UNet2DConditionModel(**unet_config).to(device)