mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Edit wrong file :<
This commit is contained in:
@@ -839,16 +839,10 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
|||||||
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_safetensors(ckpt_path):
|
checkpoint = (load_file(ckpt_path, "cpu") if is_safetensors(ckpt_path)
|
||||||
checkpoint = None
|
else torch.load(ckpt_path, map_location="cpu"))
|
||||||
state_dict = load_file(ckpt_path, "cpu")
|
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||||
else:
|
if not "state_dict" in checkpoint: checkpoint = None
|
||||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
|
||||||
if "state_dict" in checkpoint:
|
|
||||||
state_dict = checkpoint["state_dict"]
|
|
||||||
else:
|
|
||||||
state_dict = checkpoint
|
|
||||||
checkpoint = None
|
|
||||||
|
|
||||||
key_reps = []
|
key_reps = []
|
||||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||||
@@ -1103,21 +1097,15 @@ def load_vae(vae_id, dtype):
|
|||||||
# local
|
# local
|
||||||
vae_config = create_vae_diffusers_config()
|
vae_config = create_vae_diffusers_config()
|
||||||
|
|
||||||
if vae_id.endswith(".bin"):
|
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||||
# SD 1.5 VAE on Huggingface
|
else torch.load(vae_id, map_location="cpu"))
|
||||||
vae_sd = torch.load(vae_id, map_location="cpu")
|
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||||
converted_vae_checkpoint = vae_sd
|
|
||||||
else:
|
|
||||||
# StableDiffusion
|
|
||||||
vae_model = torch.load(vae_id, map_location="cpu")
|
|
||||||
vae_sd = vae_model['state_dict']
|
|
||||||
|
|
||||||
# vae only or full model
|
full_model = False
|
||||||
full_model = False
|
for vae_key in vae_sd:
|
||||||
for vae_key in vae_sd:
|
if vae_key.startswith(VAE_PREFIX):
|
||||||
if vae_key.startswith(VAE_PREFIX):
|
full_model = True
|
||||||
full_model = True
|
break
|
||||||
break
|
|
||||||
if not full_model:
|
if not full_model:
|
||||||
sd = {}
|
sd = {}
|
||||||
for key, value in vae_sd.items():
|
for key, value in vae_sd.items():
|
||||||
@@ -1125,9 +1113,7 @@ def load_vae(vae_id, dtype):
|
|||||||
vae_sd = sd
|
vae_sd = sd
|
||||||
del sd
|
del sd
|
||||||
|
|
||||||
# Convert the VAE model.
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
return vae
|
return vae
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user