mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
init unet with empty weights
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
|
||||||
@@ -156,16 +158,15 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
|||||||
|
|
||||||
# U-Net
|
# U-Net
|
||||||
print("building U-Net")
|
print("building U-Net")
|
||||||
|
with init_empty_weights():
|
||||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
|
|
||||||
print("loading U-Net from checkpoint")
|
print("loading U-Net from checkpoint")
|
||||||
unet_sd = {}
|
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
if k.startswith("model.diffusion_model."):
|
if k.startswith("model.diffusion_model."):
|
||||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k))
|
||||||
info = unet.load_state_dict(unet_sd)
|
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
|
||||||
print("U-Net: ", info)
|
# print("U-Net: ", info)
|
||||||
del unet_sd
|
|
||||||
|
|
||||||
# Text Encoders
|
# Text Encoders
|
||||||
print("building text encoders")
|
print("building text encoders")
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import os
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
import open_clip
|
import open_clip
|
||||||
@@ -92,10 +94,11 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
|||||||
del pipe
|
del pipe
|
||||||
|
|
||||||
# Diffusers U-Net to original U-Net
|
# Diffusers U-Net to original U-Net
|
||||||
original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
|
||||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||||
original_unet.load_state_dict(state_dict)
|
with init_empty_weights():
|
||||||
unet = original_unet
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k))
|
||||||
print("U-Net converted to original U-Net")
|
print("U-Net converted to original U-Net")
|
||||||
|
|
||||||
logit_scale = None
|
logit_scale = None
|
||||||
|
|||||||
Reference in New Issue
Block a user