.bin file don't need to be checked

This commit is contained in:
Hacker 17082006
2023-01-14 15:23:46 +07:00
parent b3d3f0c8ac
commit dfeadf9e52

View File

@@ -1089,6 +1089,11 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
VAE_PREFIX = "first_stage_model."
def convert_vae(vae_sd, vae_config):
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
@@ -1104,6 +1109,8 @@ def load_vae(vae_id, dtype):
# local
vae_config = create_vae_diffusers_config()
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
@@ -1120,10 +1127,8 @@ def load_vae(vae_id, dtype):
vae_sd = sd
del sd
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
return convert_vae(vae_sd, vae_config)
# endregion