From b3d3f0c8ac74357313c844237d3f1dfe2371d4fc Mon Sep 17 00:00:00 2001 From: Hacker 17082006 Date: Sat, 14 Jan 2023 15:07:56 +0700 Subject: [PATCH] Not necessary to edit load_checkpoint_with_text_encoder_conversion --- library/model_util.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index 7d16338d..18489266 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -839,10 +839,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path): ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') ] - checkpoint = (load_file(ckpt_path, "cpu") if is_safetensors(ckpt_path) - else torch.load(ckpt_path, map_location="cpu")) - state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint - if not "state_dict" in checkpoint: checkpoint = None + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path, "cpu") + else: + checkpoint = torch.load(ckpt_path, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None key_reps = [] for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: @@ -858,6 +864,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path): return checkpoint, state_dict + # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)