refactor model loading to catch error

This commit is contained in:
Isotr0py
2023-07-28 13:10:38 +08:00
parent 272dd993e6
commit 315fbc11e5
2 changed files with 31 additions and 9 deletions

View File

@@ -3,6 +3,7 @@ from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
@@ -135,7 +136,31 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
return new_sd, logit_scale
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype):
def _load_state_dict(model, state_dict, device, dtype=None):
# dtype will use fp32 as default
missing_keys = list(model.state_dict().keys() - state_dict.keys())
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
# similar to model.load_state_dict()
if not missing_keys and not unexpected_keys:
for k in list(state_dict.keys()):
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
return '<All keys matched successfully>'
else:
error_msgs: List[str] = []
if missing_keys:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
if unexpected_keys:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
# model_version is reserved for future use
# Load the state dict
@@ -165,13 +190,12 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
set_module_tensor_to_device(
unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype
)
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
# print("U-Net: ", info)
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype)
print("U-Net: ", info)
# Text Encoders
print("building text encoders")

View File

@@ -5,7 +5,6 @@ import os
from typing import Optional
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
@@ -100,8 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
for k in list(state_dict.keys()):
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype)
sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype)
print("U-Net converted to original U-Net")
logit_scale = None