mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
move convert_vae to inline, restore comments
This commit is contained in:
@@ -864,7 +864,6 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
|||||||
return checkpoint, state_dict
|
return checkpoint, state_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||||
@@ -887,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
|||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
info = vae.load_state_dict(converted_vae_checkpoint)
|
info = vae.load_state_dict(converted_vae_checkpoint)
|
||||||
print("loadint vae:", info)
|
print("loading vae:", info)
|
||||||
|
|
||||||
# convert text_model
|
# convert text_model
|
||||||
if v2:
|
if v2:
|
||||||
@@ -1089,11 +1088,6 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
|||||||
VAE_PREFIX = "first_stage_model."
|
VAE_PREFIX = "first_stage_model."
|
||||||
|
|
||||||
|
|
||||||
def convert_vae(vae_sd, vae_config):
|
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
|
||||||
|
|
||||||
def load_vae(vae_id, dtype):
|
def load_vae(vae_id, dtype):
|
||||||
print(f"load VAE: {vae_id}")
|
print(f"load VAE: {vae_id}")
|
||||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||||
@@ -1109,27 +1103,34 @@ def load_vae(vae_id, dtype):
|
|||||||
# local
|
# local
|
||||||
vae_config = create_vae_diffusers_config()
|
vae_config = create_vae_diffusers_config()
|
||||||
|
|
||||||
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
|
if vae_id.endswith(".bin"):
|
||||||
|
# SD 1.5 VAE on Huggingface
|
||||||
|
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
||||||
|
else:
|
||||||
|
# StableDiffusion
|
||||||
|
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||||
|
else torch.load(vae_id, map_location="cpu"))
|
||||||
|
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||||
|
|
||||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
# vae only or full model
|
||||||
else torch.load(vae_id, map_location="cpu"))
|
full_model = False
|
||||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
for vae_key in vae_sd:
|
||||||
|
if vae_key.startswith(VAE_PREFIX):
|
||||||
|
full_model = True
|
||||||
|
break
|
||||||
|
if not full_model:
|
||||||
|
sd = {}
|
||||||
|
for key, value in vae_sd.items():
|
||||||
|
sd[VAE_PREFIX + key] = value
|
||||||
|
vae_sd = sd
|
||||||
|
del sd
|
||||||
|
|
||||||
full_model = False
|
# Convert the VAE model.
|
||||||
for vae_key in vae_sd:
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||||
if vae_key.startswith(VAE_PREFIX):
|
|
||||||
full_model = True
|
|
||||||
break
|
|
||||||
if not full_model:
|
|
||||||
sd = {}
|
|
||||||
for key, value in vae_sd.items():
|
|
||||||
sd[VAE_PREFIX + key] = value
|
|
||||||
vae_sd = sd
|
|
||||||
del sd
|
|
||||||
|
|
||||||
|
|
||||||
return convert_vae(vae_sd, vae_config)
|
|
||||||
|
|
||||||
|
vae = AutoencoderKL(**vae_config)
|
||||||
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
return vae
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user