diff --git a/library/model_util.py b/library/model_util.py index bc824a12..6a1e656a 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): del new_sd[ANOTHER_POSITION_IDS_KEY] else: position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - + new_sd["text_model.embeddings.position_ids"] = position_ids return new_sd @@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): vae = AutoencoderKL(**vae_config) info = vae.load_state_dict(converted_vae_checkpoint) - print("loadint vae:", info) + print("loading vae:", info) # convert text_model if v2: @@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype): if vae_id.endswith(".bin"): # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location="cpu") - converted_vae_checkpoint = vae_sd + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") else: # StableDiffusion - vae_model = torch.load(vae_id, map_location="cpu") - vae_sd = vae_model['state_dict'] + vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) + else torch.load(vae_id, map_location="cpu")) + vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model # vae only or full model full_model = False @@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype): vae.load_state_dict(converted_vae_checkpoint) return vae - # endregion