diff --git a/library/model_util.py b/library/model_util.py index bc824a12..96c7bbf2 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -864,6 +864,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path): return checkpoint, state_dict + # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) @@ -1088,6 +1089,11 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod VAE_PREFIX = "first_stage_model." +def convert_vae(vae_sd, vae_config): + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + def load_vae(vae_id, dtype): print(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): @@ -1103,34 +1109,26 @@ def load_vae(vae_id, dtype): # local vae_config = create_vae_diffusers_config() - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location="cpu") - converted_vae_checkpoint = vae_sd - else: - # StableDiffusion - vae_model = torch.load(vae_id, map_location="cpu") - vae_sd = vae_model['state_dict'] + if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config) - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd + vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) + else torch.load(vae_id, map_location="cpu")) + vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model + + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae + + return convert_vae(vae_sd, vae_config) # endregion