fix mismatch dtype

This commit is contained in:
Isotr0py
2023-07-28 13:47:54 +08:00
parent 315fbc11e5
commit fdb58b0b62
2 changed files with 3 additions and 2 deletions

View File

@@ -162,6 +162,7 @@ def _load_state_dict(model, state_dict, device, dtype=None):
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
# model_version is reserved for future use # model_version is reserved for future use
# dtype is reserved for full_fp16/bf16 intergration
# Load the state dict # Load the state dict
if model_util.is_safetensors(ckpt_path): if model_util.is_safetensors(ckpt_path):
@@ -194,7 +195,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
for k in list(state_dict.keys()): for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."): if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype) info = _load_state_dict(unet, unet_sd, device=map_location)
print("U-Net: ", info) print("U-Net: ", info)
# Text Encoders # Text Encoders

View File

@@ -99,7 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights(): with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet = sdxl_original_unet.SdxlUNet2DConditionModel()
sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype) sdxl_model_util._load_state_dict(unet, state_dict, device=device)
print("U-Net converted to original U-Net") print("U-Net converted to original U-Net")
logit_scale = None logit_scale = None