mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix UNet config mismatch when saving a Diffusers model
This commit is contained in:
@@ -22,6 +22,7 @@ UNET_PARAMS_OUT_CHANNELS = 4
|
||||
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
||||
UNET_PARAMS_CONTEXT_DIM = 768
|
||||
UNET_PARAMS_NUM_HEADS = 8
|
||||
UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
||||
|
||||
VAE_PARAMS_Z_CHANNELS = 4
|
||||
VAE_PARAMS_RESOLUTION = 256
|
||||
@@ -34,6 +35,7 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
||||
# V2
|
||||
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
||||
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||
V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
||||
|
||||
# Reference models used to load the Diffusers configuration
|
||||
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
||||
@@ -207,13 +209,13 @@ def conv_attn_to_linear(checkpoint):
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def linear_transformer_to_conv(checkpoint):
    """Reshape 2-D transformer projection weights into 1x1-conv form, in place.

    Every entry of ``checkpoint`` whose key ends with the dotted components
    ``proj_in.weight`` or ``proj_out.weight`` and whose value has
    ``ndim == 2`` is replaced by that value with two trailing singleton axes
    appended. All other entries are left untouched. Nothing is returned.
    """
    targets = ("proj_in.weight", "proj_out.weight")
    # Snapshot the keys up front, since entries are reassigned during the loop.
    for name in tuple(checkpoint):
        # Only the last two dot-separated components decide a match.
        suffix = ".".join(name.rsplit(".", 2)[-2:])
        if suffix not in targets:
            continue
        weight = checkpoint[name]
        if weight.ndim == 2:
            # (out, in) -> (out, in, 1) -> (out, in, 1, 1)
            checkpoint[name] = weight.unsqueeze(2).unsqueeze(2)
|
||||
# def linear_transformer_to_conv(checkpoint):
|
||||
# keys = list(checkpoint.keys())
|
||||
# tf_keys = ["proj_in.weight", "proj_out.weight"]
|
||||
# for key in keys:
|
||||
# if ".".join(key.split(".")[-2:]) in tf_keys:
|
||||
# if checkpoint[key].ndim == 2:
|
||||
# checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
@@ -357,9 +359,9 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
|
||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||
|
||||
# In SD v2 the 1x1 conv2d layers became linear, so convert linear -> conv
|
||||
if v2:
|
||||
linear_transformer_to_conv(new_checkpoint)
|
||||
# In SD v2 the 1x1 conv2d layers became linear, but Diffusers made the same change, so no conversion is needed
|
||||
# if v2:
|
||||
# linear_transformer_to_conv(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
@@ -500,6 +502,7 @@ def create_unet_diffusers_config(v2):
|
||||
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
||||
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
||||
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
||||
use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user