diff --git a/library/model_util.py b/library/model_util.py
index 18489266..96c7bbf2 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -1089,6 +1089,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
 VAE_PREFIX = "first_stage_model."
 
 
+def convert_vae(vae_sd, vae_config):
+    """Convert vae_sd with convert_ldm_vae_checkpoint, load it into a new AutoencoderKL(**vae_config) and return that model."""
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
 def load_vae(vae_id, dtype):
     print(f"load VAE: {vae_id}")
     if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
@@ -1104,6 +1111,9 @@ def load_vae(vae_id, dtype):
     # local
     vae_config = create_vae_diffusers_config()
 
+    if vae_id.endswith(".bin"):
+        return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
+
     vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu"))
     vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
 
@@ -1120,10 +1130,8 @@ def load_vae(vae_id, dtype):
         vae_sd = sd
         del sd
 
-    converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
-    vae = AutoencoderKL(**vae_config)
-    vae.load_state_dict(converted_vae_checkpoint)
-    return vae
+
+    return convert_vae(vae_sd, vae_config)
 
 
 # endregion