fix device not specified in merge_lora.py

This commit is contained in:
Kohya S
2023-04-01 09:15:57 +09:00
parent 1cd07770a4
commit 4627b389ff

View File

@@ -812,7 +812,7 @@ def is_safetensors(path):
return os.path.splitext(path)[1].lower() == ".safetensors"
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
TEXT_ENCODER_KEY_REPLACEMENTS = [
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),