Merge pull request #70 from kohya-ss/dev

Fix loading VAE failed in some model and with .safetensors
This commit is contained in:
Kohya S
2023-01-14 21:26:41 +09:00
committed by GitHub

View File

@@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd
@@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
vae = AutoencoderKL(**vae_config)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loadint vae:", info)
print("loading vae:", info)
# convert text_model
if v2:
@@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype):
if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface
vae_sd = torch.load(vae_id, map_location="cpu")
converted_vae_checkpoint = vae_sd
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
else:
# StableDiffusion
vae_model = torch.load(vae_id, map_location="cpu")
vae_sd = vae_model['state_dict']
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
# vae only or full model
full_model = False
@@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype):
vae.load_state_dict(converted_vae_checkpoint)
return vae
# endregion