Edit wrong file :<

This commit is contained in:
Hacker 17082006
2023-01-14 14:55:57 +07:00
parent a75fd3964a
commit 95ee349e2a
2 changed files with 1281 additions and 1073 deletions

View File

@@ -839,16 +839,10 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
] ]
if is_safetensors(ckpt_path): checkpoint = (load_file(ckpt_path, "cpu") if is_safetensors(ckpt_path)
checkpoint = None else torch.load(ckpt_path, map_location="cpu"))
state_dict = load_file(ckpt_path, "cpu") state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
else: if not "state_dict" in checkpoint: checkpoint = None
checkpoint = torch.load(ckpt_path, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
checkpoint = None
key_reps = [] key_reps = []
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
@@ -1103,21 +1097,15 @@ def load_vae(vae_id, dtype):
# local # local
vae_config = create_vae_diffusers_config() vae_config = create_vae_diffusers_config()
if vae_id.endswith(".bin"): vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
# SD 1.5 VAE on Huggingface else torch.load(vae_id, map_location="cpu"))
vae_sd = torch.load(vae_id, map_location="cpu") vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
converted_vae_checkpoint = vae_sd
else:
# StableDiffusion
vae_model = torch.load(vae_id, map_location="cpu")
vae_sd = vae_model['state_dict']
# vae only or full model full_model = False
full_model = False for vae_key in vae_sd:
for vae_key in vae_sd: if vae_key.startswith(VAE_PREFIX):
if vae_key.startswith(VAE_PREFIX): full_model = True
full_model = True break
break
if not full_model: if not full_model:
sd = {} sd = {}
for key, value in vae_sd.items(): for key, value in vae_sd.items():
@@ -1125,9 +1113,7 @@ def load_vae(vae_id, dtype):
vae_sd = sd vae_sd = sd
del sd del sd
# Convert the VAE model. converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
return vae return vae

File diff suppressed because it is too large Load Diff