mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Not necessary to edit load_checkpoint_with_text_encoder_conversion
This commit is contained in:
@@ -839,10 +839,16 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
||||
]
|
||||
|
||||
checkpoint = (load_file(ckpt_path, "cpu") if is_safetensors(ckpt_path)
|
||||
else torch.load(ckpt_path, map_location="cpu"))
|
||||
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
if not "state_dict" in checkpoint: checkpoint = None
|
||||
if is_safetensors(ckpt_path):
|
||||
checkpoint = None
|
||||
state_dict = load_file(ckpt_path, "cpu")
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
checkpoint = None
|
||||
|
||||
key_reps = []
|
||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||
@@ -858,6 +864,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
return checkpoint, state_dict
|
||||
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
|
||||
Reference in New Issue
Block a user