mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use original unet for HF models, don't download TE
This commit is contained in:
@@ -99,12 +99,6 @@ from library.original_unet import FlashAttentionFunction
|
|||||||
|
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
|
||||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
|
||||||
|
|
||||||
DEFAULT_TOKEN_LENGTH = 75
|
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
SCHEDULER_LINEAR_END = 0.0120
|
SCHEDULER_LINEAR_END = 0.0120
|
||||||
@@ -2066,6 +2060,17 @@ def main(args):
|
|||||||
tokenizer = loading_pipe.tokenizer
|
tokenizer = loading_pipe.tokenizer
|
||||||
del loading_pipe
|
del loading_pipe
|
||||||
|
|
||||||
|
# Diffusers U-Net to original U-Net
|
||||||
|
original_unet = UNet2DConditionModel(
|
||||||
|
unet.config.sample_size,
|
||||||
|
unet.config.attention_head_dim,
|
||||||
|
unet.config.cross_attention_dim,
|
||||||
|
unet.config.use_linear_projection,
|
||||||
|
unet.config.upcast_attention,
|
||||||
|
)
|
||||||
|
original_unet.load_state_dict(unet.state_dict())
|
||||||
|
unet = original_unet
|
||||||
|
|
||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
vae = model_util.load_vae(args.vae, dtype)
|
vae = model_util.load_vae(args.vae, dtype)
|
||||||
|
|||||||
@@ -933,10 +933,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
|||||||
else:
|
else:
|
||||||
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
||||||
|
|
||||||
logging.set_verbosity_error() # don't show annoying warning
|
# logging.set_verbosity_error() # don't show annoying warning
|
||||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||||
logging.set_verbosity_warning()
|
# logging.set_verbosity_warning()
|
||||||
|
# print(f"config: {text_model.config}")
|
||||||
|
cfg = CLIPTextConfig(
|
||||||
|
vocab_size=49408,
|
||||||
|
hidden_size=768,
|
||||||
|
intermediate_size=3072,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
num_attention_heads=12,
|
||||||
|
max_position_embeddings=77,
|
||||||
|
hidden_act="quick_gelu",
|
||||||
|
layer_norm_eps=1e-05,
|
||||||
|
dropout=0.0,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
initializer_factor=1.0,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
model_type="clip_text_model",
|
||||||
|
projection_dim=768,
|
||||||
|
torch_dtype="float32",
|
||||||
|
)
|
||||||
|
text_model = CLIPTextModel._from_config(cfg)
|
||||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||||
print("loading text encoder:", info)
|
print("loading text encoder:", info)
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from torch.optim import Optimizer
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
import transformers
|
import transformers
|
||||||
import diffusers
|
|
||||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
@@ -52,6 +51,7 @@ from diffusers import (
|
|||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
)
|
)
|
||||||
|
from library.original_unet import UNet2DConditionModel
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import albumentations as albu
|
import albumentations as albu
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -2947,11 +2947,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
|||||||
print(
|
print(
|
||||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
||||||
)
|
)
|
||||||
|
raise ex
|
||||||
text_encoder = pipe.text_encoder
|
text_encoder = pipe.text_encoder
|
||||||
vae = pipe.vae
|
vae = pipe.vae
|
||||||
unet = pipe.unet
|
unet = pipe.unet
|
||||||
del pipe
|
del pipe
|
||||||
|
|
||||||
|
# Diffusers U-Net to original U-Net
|
||||||
|
# TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう
|
||||||
|
# print(f"unet config: {unet.config}")
|
||||||
|
original_unet = UNet2DConditionModel(
|
||||||
|
unet.config.sample_size,
|
||||||
|
unet.config.attention_head_dim,
|
||||||
|
unet.config.cross_attention_dim,
|
||||||
|
unet.config.use_linear_projection,
|
||||||
|
unet.config.upcast_attention,
|
||||||
|
)
|
||||||
|
original_unet.load_state_dict(unet.state_dict())
|
||||||
|
unet = original_unet
|
||||||
|
print("U-Net converted to original U-Net")
|
||||||
|
|
||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user