mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
refactor model loading to catch error
This commit is contained in:
@@ -5,7 +5,6 @@ import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
@@ -100,8 +99,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
for k in list(state_dict.keys()):
|
||||
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype)
|
||||
sdxl_model_util._load_state_dict(unet, state_dict, device=device, dtype=weight_dtype)
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
logit_scale = None
|
||||
|
||||
Reference in New Issue
Block a user