mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
.bin file don't need to be checked
This commit is contained in:
@@ -1089,6 +1089,11 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
|||||||
VAE_PREFIX = "first_stage_model."
|
VAE_PREFIX = "first_stage_model."
|
||||||
|
|
||||||
|
|
||||||
|
def convert_vae(vae_sd, vae_config):
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||||
|
vae = AutoencoderKL(**vae_config)
|
||||||
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
|
||||||
def load_vae(vae_id, dtype):
|
def load_vae(vae_id, dtype):
|
||||||
print(f"load VAE: {vae_id}")
|
print(f"load VAE: {vae_id}")
|
||||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||||
@@ -1104,6 +1109,8 @@ def load_vae(vae_id, dtype):
|
|||||||
# local
|
# local
|
||||||
vae_config = create_vae_diffusers_config()
|
vae_config = create_vae_diffusers_config()
|
||||||
|
|
||||||
|
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
|
||||||
|
|
||||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||||
else torch.load(vae_id, map_location="cpu"))
|
else torch.load(vae_id, map_location="cpu"))
|
||||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||||
@@ -1120,10 +1127,8 @@ def load_vae(vae_id, dtype):
|
|||||||
vae_sd = sd
|
vae_sd = sd
|
||||||
del sd
|
del sd
|
||||||
|
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
return convert_vae(vae_sd, vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
|
||||||
return vae
|
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
Reference in New Issue
Block a user