Use the original U-Net for Hugging Face (HF) models; don't download the text encoder (TE)

This commit is contained in:
Kohya S
2023-06-14 22:26:05 +09:00
parent 44404fcd6d
commit 449ad7502c
3 changed files with 52 additions and 11 deletions

View File

@@ -99,12 +99,6 @@ from library.original_unet import FlashAttentionFunction
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEFAULT_TOKEN_LENGTH = 75
# scheduler: # scheduler:
SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_LINEAR_END = 0.0120
@@ -2066,6 +2060,17 @@ def main(args):
tokenizer = loading_pipe.tokenizer tokenizer = loading_pipe.tokenizer
del loading_pipe del loading_pipe
# Diffusers U-Net to original U-Net
original_unet = UNet2DConditionModel(
unet.config.sample_size,
unet.config.attention_head_dim,
unet.config.cross_attention_dim,
unet.config.use_linear_projection,
unet.config.upcast_attention,
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
# VAEを読み込む # VAEを読み込む
if args.vae is not None: if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype) vae = model_util.load_vae(args.vae, dtype)

View File

@@ -933,10 +933,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
else: else:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
logging.set_verbosity_error() # don't show annoying warning # logging.set_verbosity_error() # don't show annoying warning
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
logging.set_verbosity_warning() # logging.set_verbosity_warning()
# print(f"config: {text_model.config}")
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
torch_dtype="float32",
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint) info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info) print("loading text encoder:", info)

View File

@@ -36,7 +36,6 @@ from torch.optim import Optimizer
from torchvision import transforms from torchvision import transforms
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
import transformers import transformers
import diffusers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import ( from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
@@ -52,6 +51,7 @@ from diffusers import (
KDPM2DiscreteScheduler, KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
) )
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import albumentations as albu import albumentations as albu
import numpy as np import numpy as np
@@ -2947,11 +2947,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
print( print(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
) )
raise ex
text_encoder = pipe.text_encoder text_encoder = pipe.text_encoder
vae = pipe.vae vae = pipe.vae
unet = pipe.unet unet = pipe.unet
del pipe del pipe
# Diffusers U-Net to original U-Net
# TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう
# print(f"unet config: {unet.config}")
original_unet = UNet2DConditionModel(
unet.config.sample_size,
unet.config.attention_head_dim,
unet.config.cross_attention_dim,
unet.config.use_linear_projection,
unet.config.upcast_attention,
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
print("U-Net converted to original U-Net")
# VAEを読み込む # VAEを読み込む
if args.vae is not None: if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype) vae = model_util.load_vae(args.vae, weight_dtype)