From b1dffe8d9ae1c02a06e8871a844c42d6729623ce Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 00:11:11 +0900 Subject: [PATCH] =?UTF-8?q?=E3=83=95=E3=82=A1=E3=82=A4=E3=83=AB=E3=83=AD?= =?UTF-8?q?=E3=83=BC=E3=83=89=E3=81=8C=E3=81=A7=E3=81=8D=E3=81=AA=E3=81=84?= =?UTF-8?q?=E3=83=90=E3=82=B0=E4=BF=AE=E6=AD=A3(Exception:=20device=20cuda?= =?UTF-8?q?=20is=20invalid)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/model_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index e227ced8..f3f236af 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -831,7 +831,7 @@ def is_safetensors(path): return os.path.splitext(path)[1].lower() == '.safetensors' -def load_checkpoint_with_text_encoder_conversion(ckpt_path, device): +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # text encoderの格納形式が違うモデルに対応する ('text_model'がない) TEXT_ENCODER_KEY_REPLACEMENTS = [ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), @@ -866,7 +866,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) # no need to specify device in loading state_dict # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(v2)