From 758a1e7f666a9c081cee9702c9c5288c6607a59e Mon Sep 17 00:00:00 2001 From: ykume Date: Wed, 3 May 2023 16:05:15 +0900 Subject: [PATCH] Revert unet config, add option to convert script --- library/model_util.py | 35 +++++++++++++----------- tools/convert_diffusers20_original_sd.py | 5 +++- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index 8eea76da..26f72235 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -22,7 +22,7 @@ UNET_PARAMS_OUT_CHANNELS = 4 UNET_PARAMS_NUM_RES_BLOCKS = 2 UNET_PARAMS_CONTEXT_DIM = 768 UNET_PARAMS_NUM_HEADS = 8 -UNET_PARAMS_USE_LINEAR_PROJECTION = False +# UNET_PARAMS_USE_LINEAR_PROJECTION = False VAE_PARAMS_Z_CHANNELS = 4 VAE_PARAMS_RESOLUTION = 256 @@ -35,7 +35,7 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2 # V2 V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] V2_UNET_PARAMS_CONTEXT_DIM = 1024 -V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True +# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True # Diffusersの設定を読み込むための参照モデル DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" @@ -209,13 +209,13 @@ def conv_attn_to_linear(checkpoint): checkpoint[key] = checkpoint[key][:, :, 0] -# def linear_transformer_to_conv(checkpoint): -# keys = list(checkpoint.keys()) -# tf_keys = ["proj_in.weight", "proj_out.weight"] -# for key in keys: -# if ".".join(key.split(".")[-2:]) in tf_keys: -# if checkpoint[key].ndim == 2: -# checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) def convert_ldm_unet_checkpoint(v2, checkpoint, config): @@ -359,9 +359,10 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): new_checkpoint[new_path] = unet_state_dict[old_path] - # SDのv2では1*1のconv2dがlinearに変わっているが、Diffusers側も同じなので、変換不要 - # if v2: - # linear_transformer_to_conv(new_checkpoint) + # SDのv2では1*1のconv2dがlinearに変わっている + # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 + if v2 and not config.get('use_linear_projection', False): + linear_transformer_to_conv(new_checkpoint) return new_checkpoint @@ -470,7 +471,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config): return new_checkpoint -def create_unet_diffusers_config(v2): +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): """ Creates a config for the diffusers based on the config of the LDM model. """ @@ -502,8 +503,10 @@ def create_unet_diffusers_config(v2): layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True return config @@ -849,11 +852,11 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None): +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) + unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2) converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) unet = UNet2DConditionModel(**unet_config).to(device) diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index 130eff1f..b9365b51 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -34,7 +34,7 @@ def convert(args): if is_load_ckpt: v2_model = args.v2 - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection) else: pipe = StableDiffusionPipeline.from_pretrained( args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None @@ -76,6 +76,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" ) + parser.add_argument( + "--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)" + ) parser.add_argument( "--fp16", action="store_true",