mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
refactor model loading to catch error
This commit is contained in:
@@ -3,6 +3,7 @@ from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from typing import List
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
@@ -135,7 +136,31 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
return new_sd, logit_scale
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype):
|
||||
def _load_state_dict(model, state_dict, device, dtype=None):
|
||||
# dtype will use fp32 as default
|
||||
missing_keys = list(model.state_dict().keys() - state_dict.keys())
|
||||
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
|
||||
|
||||
# similar to model.load_state_dict()
|
||||
if not missing_keys and not unexpected_keys:
|
||||
for k in list(state_dict.keys()):
|
||||
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
||||
return '<All keys matched successfully>'
|
||||
else:
|
||||
error_msgs: List[str] = []
|
||||
if missing_keys:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in unexpected_keys)))
|
||||
if unexpected_keys:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
||||
# model_version is reserved for future use
|
||||
|
||||
# Load the state dict
|
||||
@@ -165,13 +190,12 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
|
||||
print("loading U-Net from checkpoint")
|
||||
unet_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("model.diffusion_model."):
|
||||
set_module_tensor_to_device(
|
||||
unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype
|
||||
)
|
||||
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
|
||||
# print("U-Net: ", info)
|
||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
||||
info = _load_state_dict(unet, unet_sd, device=map_location, dtype=dtype)
|
||||
print("U-Net: ", info)
|
||||
|
||||
# Text Encoders
|
||||
print("building text encoders")
|
||||
|
||||
@@ -5,7 +5,6 @@ import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
@@ -100,8 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
for k in list(state_dict.keys()):
|
||||
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype)
|
||||
sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype)
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
logit_scale = None
|
||||
|
||||
Reference in New Issue
Block a user