It is not necessary to edit load_checkpoint_with_text_encoder_conversion

This commit is contained in:
Hacker 17082006
2023-01-14 15:07:56 +07:00
parent 4fe1dd6a1c
commit b3d3f0c8ac

View File

@@ -839,10 +839,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
]
checkpoint = (load_file(ckpt_path, "cpu") if is_safetensors(ckpt_path)
else torch.load(ckpt_path, map_location="cpu"))
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
if not "state_dict" in checkpoint: checkpoint = None
if is_safetensors(ckpt_path):
checkpoint = None
state_dict = load_file(ckpt_path, "cpu")
else:
checkpoint = torch.load(ckpt_path, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
checkpoint = None
key_reps = []
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
@@ -858,6 +864,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
return checkpoint, state_dict
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)