Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
change tokenizer from open clip to transformers
@@ -12,7 +12,8 @@ from diffusers import StableDiffusionXLPipeline
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
 
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
+TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
+TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 
 DEFAULT_NOISE_OFFSET = 0.0357
 
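After this change each SDXL text encoder gets its own tokenizer repo id instead of the single TOKENIZER_PATH. A minimal sketch (not part of the commit; assumes transformers is installed and the Hugging Face Hub is reachable) of loading both with the stock CLIPTokenizer:

# Sketch: load the two tokenizers named by TOKENIZER1_PATH / TOKENIZER2_PATH.
from transformers import CLIPTokenizer

tokenizer1 = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")             # text encoder 1 (ViT-L)
tokenizer2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")  # text encoder 2 (ViT-bigG)

print(tokenizer1.model_max_length, len(tokenizer1))  # 77, 49408 for the ViT-L tokenizer
print(tokenizer2.model_max_length, len(tokenizer2))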
@@ -108,101 +109,32 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
 
 
-class WrapperTokenizer:
-    # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
-    # make open clip tokenizer compatible with HuggingFace tokenizer
-    def __init__(self):
-        open_clip_tokenizer = open_clip.tokenizer._tokenizer
-        self.model_max_length = 77
-        self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
-        self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
-        self.pad_token_id = 0  # 結果から推定している assumption from result
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.tokenize(*args, **kwds)
-
-    def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
-        if padding == "max_length":
-            # for training
-            assert max_length is not None
-            assert truncation == True
-            assert return_tensors == "pt"
-            input_ids = open_clip.tokenize(text, context_length=max_length)
-            return SimpleNamespace(**{"input_ids": input_ids})
-
-        # for weighted prompt
-        assert isinstance(text, str), f"input must be str: {text}"
-
-        input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0]  # tokenizer returns list
-
-        # find eos
-        eos_index = (input_ids == self.eos_token_id).nonzero().max()
-        input_ids = input_ids[: eos_index + 1]  # include eos
-        return SimpleNamespace(**{"input_ids": input_ids})
-
-    # for Textual Inversion
-    # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
-
-    def encode(self, text, add_special_tokens=False):
-        assert not add_special_tokens
-        input_ids = open_clip.tokenizer._tokenizer.encode(text)
-        return input_ids
-
-    def add_tokens(self, new_tokens):
-        tokens_to_add = []
-        for token in new_tokens:
-            token = token.lower()
-            if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
-                tokens_to_add.append(token)
-
-        # open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
-        for token in tokens_to_add:
-            open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
-            open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
-            open_clip.tokenizer._tokenizer.vocab_size += 1
-
-            # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
-            # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
-            # set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
-            # this is very rough, so it will not work if the specification of open clip tokenizer changes
-            open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
-
-        return len(tokens_to_add)
-
-    def convert_tokens_to_ids(self, tokens):
-        input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
-        return input_ids
-
-    def __len__(self):
-        return open_clip.tokenizer._tokenizer.vocab_size
-
-
 def load_tokenizers(args: argparse.Namespace):
     print("prepare tokenizers")
-    original_path = TOKENIZER_PATH
 
-    tokenizer1: CLIPTokenizer = None
-    if args.tokenizer_cache_dir:
-        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
-        if os.path.exists(local_tokenizer_path):
-            print(f"load tokenizer from cache: {local_tokenizer_path}")
-            tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)
+    original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
+    tokeniers = []
+    for original_path in original_paths:
+        tokenizer: CLIPTokenizer = None
+        if args.tokenizer_cache_dir:
+            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
+            if os.path.exists(local_tokenizer_path):
+                print(f"load tokenizer from cache: {local_tokenizer_path}")
+                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
 
-    if tokenizer1 is None:
-        tokenizer1 = CLIPTokenizer.from_pretrained(original_path)
+        if tokenizer is None:
+            tokenizer = CLIPTokenizer.from_pretrained(original_path)
 
-    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
-        print(f"save Tokenizer to cache: {local_tokenizer_path}")
-        tokenizer1.save_pretrained(local_tokenizer_path)
+        if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+            print(f"save Tokenizer to cache: {local_tokenizer_path}")
+            tokenizer.save_pretrained(local_tokenizer_path)
 
+        tokeniers.append(tokenizer)
+
     if hasattr(args, "max_token_length") and args.max_token_length is not None:
         print(f"update token length: {args.max_token_length}")
 
-    # tokenizer2 is from open_clip
-    # TODO caching
-    tokenizer2 = WrapperTokenizer()
-
-    return [tokenizer1, tokenizer2]
+    return tokeniers
 
 
 def get_hidden_states(
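The removed WrapperTokenizer existed only to make open_clip's tokenizer answer the HuggingFace-style call used during training (padding="max_length", truncation=True, return_tensors="pt", returning an object exposing input_ids). With a plain CLIPTokenizer that call works directly; a sketch outside the diff (assumes transformers and torch are installed):

# Sketch: the call pattern WrapperTokenizer.tokenize() emulated, issued directly
# against the HuggingFace tokenizer for the second text encoder.
from transformers import CLIPTokenizer

tokenizer2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

batch = tokenizer2(
    "a photo of a cat",
    padding="max_length",   # pad every prompt to the fixed context length
    truncation=True,
    max_length=77,          # the removed wrapper hard-coded model_max_length = 77
    return_tensors="pt",
)
print(batch.input_ids.shape)  # torch.Size([1, 77])

One difference worth noting: the removed wrapper assumed pad_token_id = 0 for open_clip ("assumption from result" in the old code), while the HuggingFace CLIP tokenizer may use a different padding id, so code that depends on the pad id should check it explicitly.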
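The caching behind args.tokenizer_cache_dir in the new load_tokenizers (derive a local directory from the repo id by replacing "/" with "_", load from it if present, otherwise download and save) can be summarized as a standalone helper. This is a sketch only; the helper name and signature are not from the repository:

# Sketch of the args.tokenizer_cache_dir logic in load_tokenizers (hypothetical helper).
import os
from typing import Optional

from transformers import CLIPTokenizer

def load_cached_tokenizer(repo_id: str, cache_dir: Optional[str]) -> CLIPTokenizer:
    if cache_dir:
        # e.g. "openai/clip-vit-large-patch14" -> "openai_clip-vit-large-patch14"
        local_path = os.path.join(cache_dir, repo_id.replace("/", "_"))
        if os.path.exists(local_path):
            return CLIPTokenizer.from_pretrained(local_path)  # cache hit: load locally

    tokenizer = CLIPTokenizer.from_pretrained(repo_id)  # cache miss: download from the Hub
    if cache_dir:
        tokenizer.save_pretrained(local_path)  # populate the cache for next time
    return tokenizer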