refactor model loading to catch error

This commit is contained in:
Isotr0py
2023-07-28 13:10:38 +08:00
parent 272dd993e6
commit 315fbc11e5
2 changed files with 31 additions and 9 deletions

View File

@@ -5,7 +5,6 @@ import os
from typing import Optional
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
@@ -100,8 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
for k in list(state_dict.keys()):
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype)
sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype)
print("U-Net converted to original U-Net")
logit_scale = None