move convert_vae to inline, restore comments

This commit is contained in:
Kohya S
2023-01-14 21:24:09 +09:00
parent 199a3cbae4
commit 61ec60a893

View File

@@ -864,7 +864,6 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
return checkpoint, state_dict return checkpoint, state_dict
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
@@ -887,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
info = vae.load_state_dict(converted_vae_checkpoint) info = vae.load_state_dict(converted_vae_checkpoint)
print("loadint vae:", info) print("loading vae:", info)
# convert text_model # convert text_model
if v2: if v2:
@@ -1089,11 +1088,6 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
VAE_PREFIX = "first_stage_model." VAE_PREFIX = "first_stage_model."
def convert_vae(vae_sd, vae_config):
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
def load_vae(vae_id, dtype): def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}") print(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id): if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
@@ -1109,27 +1103,34 @@ def load_vae(vae_id, dtype):
# local # local
vae_config = create_vae_diffusers_config() vae_config = create_vae_diffusers_config()
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config) if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
else:
# StableDiffusion
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) # vae only or full model
else torch.load(vae_id, map_location="cpu")) full_model = False
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model for vae_key in vae_sd:
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
if not full_model:
sd = {}
for key, value in vae_sd.items():
sd[VAE_PREFIX + key] = value
vae_sd = sd
del sd
full_model = False # Convert the VAE model.
for vae_key in vae_sd: converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
if not full_model:
sd = {}
for key, value in vae_sd.items():
sd[VAE_PREFIX + key] = value
vae_sd = sd
del sd
return convert_vae(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
# endregion # endregion