Edit wrong file :<

This commit is contained in:
Hacker 17082006
2023-01-14 14:55:57 +07:00
parent a75fd3964a
commit 95ee349e2a
2 changed files with 1281 additions and 1073 deletions

View File

@@ -839,16 +839,10 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
]
if is_safetensors(ckpt_path):
checkpoint = None
state_dict = load_file(ckpt_path, "cpu")
else:
checkpoint = torch.load(ckpt_path, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
checkpoint = None
checkpoint = (load_file(ckpt_path, "cpu") if is_safetensors(ckpt_path)
else torch.load(ckpt_path, map_location="cpu"))
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
if not "state_dict" in checkpoint: checkpoint = None
key_reps = []
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
@@ -1103,21 +1097,15 @@ def load_vae(vae_id, dtype):
# local
vae_config = create_vae_diffusers_config()
if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface
vae_sd = torch.load(vae_id, map_location="cpu")
converted_vae_checkpoint = vae_sd
else:
# StableDiffusion
vae_model = torch.load(vae_id, map_location="cpu")
vae_sd = vae_model['state_dict']
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
# vae only or full model
full_model = False
for vae_key in vae_sd:
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
full_model = False
for vae_key in vae_sd:
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
if not full_model:
sd = {}
for key, value in vae_sd.items():
@@ -1125,9 +1113,7 @@ def load_vae(vae_id, dtype):
vae_sd = sd
del sd
# Convert the VAE model.
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae

File diff suppressed because it is too large Load Diff