mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix unet cfg is different in saving diffuser model
This commit is contained in:
@@ -22,6 +22,7 @@ UNET_PARAMS_OUT_CHANNELS = 4
|
|||||||
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
||||||
UNET_PARAMS_CONTEXT_DIM = 768
|
UNET_PARAMS_CONTEXT_DIM = 768
|
||||||
UNET_PARAMS_NUM_HEADS = 8
|
UNET_PARAMS_NUM_HEADS = 8
|
||||||
|
UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
||||||
|
|
||||||
VAE_PARAMS_Z_CHANNELS = 4
|
VAE_PARAMS_Z_CHANNELS = 4
|
||||||
VAE_PARAMS_RESOLUTION = 256
|
VAE_PARAMS_RESOLUTION = 256
|
||||||
@@ -34,6 +35,7 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
|||||||
# V2
|
# V2
|
||||||
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
||||||
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||||
|
V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
||||||
|
|
||||||
# Diffusersの設定を読み込むための参照モデル
|
# Diffusersの設定を読み込むための参照モデル
|
||||||
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
||||||
@@ -207,13 +209,13 @@ def conv_attn_to_linear(checkpoint):
|
|||||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
def linear_transformer_to_conv(checkpoint):
|
# def linear_transformer_to_conv(checkpoint):
|
||||||
keys = list(checkpoint.keys())
|
# keys = list(checkpoint.keys())
|
||||||
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
# tf_keys = ["proj_in.weight", "proj_out.weight"]
|
||||||
for key in keys:
|
# for key in keys:
|
||||||
if ".".join(key.split(".")[-2:]) in tf_keys:
|
# if ".".join(key.split(".")[-2:]) in tf_keys:
|
||||||
if checkpoint[key].ndim == 2:
|
# if checkpoint[key].ndim == 2:
|
||||||
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
# checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||||
@@ -357,9 +359,9 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
|||||||
|
|
||||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||||
|
|
||||||
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
# SDのv2では1*1のconv2dがlinearに変わっているが、Diffusers側も同じなので、変換不要
|
||||||
if v2:
|
# if v2:
|
||||||
linear_transformer_to_conv(new_checkpoint)
|
# linear_transformer_to_conv(new_checkpoint)
|
||||||
|
|
||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
@@ -500,6 +502,7 @@ def create_unet_diffusers_config(v2):
|
|||||||
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
||||||
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
||||||
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
||||||
|
use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ def convert(args):
|
|||||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||||
|
|
||||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||||
assert (
|
# assert (
|
||||||
is_save_ckpt or args.reference_model is not None
|
# is_save_ckpt or args.reference_model is not None
|
||||||
), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||||
@@ -61,7 +61,7 @@ def convert(args):
|
|||||||
)
|
)
|
||||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||||
else:
|
else:
|
||||||
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
|
||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||||
)
|
)
|
||||||
@@ -100,7 +100,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--reference_model",
|
"--reference_model",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要",
|
help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_safetensors",
|
"--use_safetensors",
|
||||||
|
|||||||
Reference in New Issue
Block a user