Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)
Revert unet config, add option to convert script
This commit is contained in:
@@ -22,7 +22,7 @@ UNET_PARAMS_OUT_CHANNELS = 4
|
|||||||
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
||||||
UNET_PARAMS_CONTEXT_DIM = 768
|
UNET_PARAMS_CONTEXT_DIM = 768
|
||||||
UNET_PARAMS_NUM_HEADS = 8
|
UNET_PARAMS_NUM_HEADS = 8
|
||||||
UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
||||||
|
|
||||||
VAE_PARAMS_Z_CHANNELS = 4
|
VAE_PARAMS_Z_CHANNELS = 4
|
||||||
VAE_PARAMS_RESOLUTION = 256
|
VAE_PARAMS_RESOLUTION = 256
|
||||||
@@ -35,7 +35,7 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
|||||||
# V2
|
# V2
|
||||||
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
||||||
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||||
V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
||||||
|
|
||||||
# Diffusersの設定を読み込むための参照モデル
|
# Diffusersの設定を読み込むための参照モデル
|
||||||
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
||||||
@@ -209,13 +209,13 @@ def conv_attn_to_linear(checkpoint):
|
|||||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
# def linear_transformer_to_conv(checkpoint):
|
def linear_transformer_to_conv(checkpoint):
|
||||||
# keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
# tf_keys = ["proj_in.weight", "proj_out.weight"]
|
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
||||||
# for key in keys:
|
for key in keys:
|
||||||
# if ".".join(key.split(".")[-2:]) in tf_keys:
|
if ".".join(key.split(".")[-2:]) in tf_keys:
|
||||||
# if checkpoint[key].ndim == 2:
|
if checkpoint[key].ndim == 2:
|
||||||
# checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||||
@@ -359,9 +359,10 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
|||||||
|
|
||||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||||
|
|
||||||
# SDのv2では1*1のconv2dがlinearに変わっているが、Diffusers側も同じなので、変換不要
|
# SDのv2では1*1のconv2dがlinearに変わっている
|
||||||
# if v2:
|
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
||||||
# linear_transformer_to_conv(new_checkpoint)
|
if v2 and not config.get('use_linear_projection', False):
|
||||||
|
linear_transformer_to_conv(new_checkpoint)
|
||||||
|
|
||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
@@ -470,7 +471,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def create_unet_diffusers_config(v2):
|
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
||||||
"""
|
"""
|
||||||
Creates a config for the diffusers based on the config of the LDM model.
|
Creates a config for the diffusers based on the config of the LDM model.
|
||||||
"""
|
"""
|
||||||
@@ -502,8 +503,10 @@ def create_unet_diffusers_config(v2):
|
|||||||
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
||||||
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
||||||
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
||||||
use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
||||||
)
|
)
|
||||||
|
if v2 and use_linear_projection_in_v2:
|
||||||
|
config["use_linear_projection"] = True
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -849,11 +852,11 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
|||||||
|
|
||||||
|
|
||||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
|
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
|
||||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
||||||
|
|
||||||
# Convert the UNet2DConditionModel model.
|
# Convert the UNet2DConditionModel model.
|
||||||
unet_config = create_unet_diffusers_config(v2)
|
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
||||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
||||||
|
|
||||||
unet = UNet2DConditionModel(**unet_config).to(device)
|
unet = UNet2DConditionModel(**unet_config).to(device)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def convert(args):
|
|||||||
|
|
||||||
if is_load_ckpt:
|
if is_load_ckpt:
|
||||||
v2_model = args.v2
|
v2_model = args.v2
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
|
||||||
else:
|
else:
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
||||||
@@ -76,6 +76,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fp16",
|
"--fp16",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user