mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix mismatch dtype
This commit is contained in:
@@ -162,6 +162,7 @@ def _load_state_dict(model, state_dict, device, dtype=None):
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
||||
# model_version is reserved for future use
|
||||
# dtype is reserved for full_fp16/bf16 intergration
|
||||
|
||||
# Load the state dict
|
||||
if model_util.is_safetensors(ckpt_path):
|
||||
@@ -194,7 +195,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("model.diffusion_model."):
|
||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
||||
info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype)
|
||||
info = _load_state_dict(unet, unet_sd, device=map_location)
|
||||
print("U-Net: ", info)
|
||||
|
||||
# Text Encoders
|
||||
|
||||
@@ -99,7 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype)
|
||||
sdxl_model_util._load_state_dict(unet, state_dict, device=device)
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
logit_scale = None
|
||||
|
||||
Reference in New Issue
Block a user