mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
.bin file don't need to be checked
This commit is contained in:
@@ -1089,6 +1089,11 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
||||
VAE_PREFIX = "first_stage_model."
|
||||
|
||||
|
||||
def convert_vae(vae_sd, vae_config):
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
def load_vae(vae_id, dtype):
|
||||
print(f"load VAE: {vae_id}")
|
||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||
@@ -1104,6 +1109,8 @@ def load_vae(vae_id, dtype):
|
||||
# local
|
||||
vae_config = create_vae_diffusers_config()
|
||||
|
||||
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
|
||||
|
||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||
else torch.load(vae_id, map_location="cpu"))
|
||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||
@@ -1120,10 +1127,8 @@ def load_vae(vae_id, dtype):
|
||||
vae_sd = sd
|
||||
del sd
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
return convert_vae(vae_sd, vae_config)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user