init unet with empty weights

This commit is contained in:
Isotr0py
2023-07-23 13:17:11 +08:00
parent d1864e2430
commit bb167f94ca
2 changed files with 13 additions and 9 deletions

View File

@@ -1,4 +1,6 @@
 import torch
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
 from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
 from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
@@ -156,16 +158,15 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
 # U-Net
 print("building U-Net")
-unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+with init_empty_weights():
+    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
 print("loading U-Net from checkpoint")
-unet_sd = {}
 for k in list(state_dict.keys()):
     if k.startswith("model.diffusion_model."):
-        unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
+        set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k))
-info = unet.load_state_dict(unet_sd)
-print("U-Net: ", info)
-del unet_sd
+# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
+# print("U-Net: ", info)
 # Text Encoders
 print("building text encoders")

View File

@@ -5,6 +5,8 @@ import os
 from types import SimpleNamespace
 from typing import Any
 import torch
+from accelerate import init_empty_weights
+from accelerate.utils.modeling import set_module_tensor_to_device
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 import open_clip
@@ -92,10 +94,11 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
 del pipe
 # Diffusers U-Net to original U-Net
-original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
 state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
-original_unet.load_state_dict(state_dict)
-unet = original_unet
+with init_empty_weights():
+    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
+for k in list(state_dict.keys()):
+    set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k))
 print("U-Net converted to original U-Net")
 logit_scale = None