change tokenizer from open clip to transformers

This commit is contained in:
Kohya S
2023-07-13 20:49:26 +09:00
parent 3bb80ebf20
commit b4a3824ce4
4 changed files with 27 additions and 116 deletions

View File

@@ -12,7 +12,8 @@ from diffusers import StableDiffusionXLPipeline
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
DEFAULT_NOISE_OFFSET = 0.0357
@@ -108,101 +109,32 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
class WrapperTokenizer:
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
# make open clip tokenizer compatible with HuggingFace tokenizer
def __init__(self):
open_clip_tokenizer = open_clip.tokenizer._tokenizer
self.model_max_length = 77
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
self.pad_token_id = 0 # 結果から推定している assumption from result
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.tokenize(*args, **kwds)
def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
if padding == "max_length":
# for training
assert max_length is not None
assert truncation == True
assert return_tensors == "pt"
input_ids = open_clip.tokenize(text, context_length=max_length)
return SimpleNamespace(**{"input_ids": input_ids})
# for weighted prompt
assert isinstance(text, str), f"input must be str: {text}"
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list
# find eos
eos_index = (input_ids == self.eos_token_id).nonzero().max()
input_ids = input_ids[: eos_index + 1] # include eos
return SimpleNamespace(**{"input_ids": input_ids})
# for Textual Inversion
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
def encode(self, text, add_special_tokens=False):
assert not add_special_tokens
input_ids = open_clip.tokenizer._tokenizer.encode(text)
return input_ids
def add_tokens(self, new_tokens):
tokens_to_add = []
for token in new_tokens:
token = token.lower()
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
tokens_to_add.append(token)
# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
for token in tokens_to_add:
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
open_clip.tokenizer._tokenizer.vocab_size += 1
# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
# this is very rough, so it will not work if the specification of open clip tokenizer changes
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
return len(tokens_to_add)
def convert_tokens_to_ids(self, tokens):
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
return input_ids
def __len__(self):
return open_clip.tokenizer._tokenizer.vocab_size
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
original_path = TOKENIZER_PATH
tokenizer1: CLIPTokenizer = None
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokeniers = []
for original_path in original_paths:
tokenizer: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer1 is None:
tokenizer1 = CLIPTokenizer.from_pretrained(original_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer1.save_pretrained(local_tokenizer_path)
tokenizer.save_pretrained(local_tokenizer_path)
tokeniers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
# tokenizer2 is from open_clip
# TODO caching
tokenizer2 = WrapperTokenizer()
return [tokenizer1, tokenizer2]
return tokeniers
def get_hidden_states(

View File

@@ -1605,18 +1605,14 @@ def main(args):
num_vectors_per_token = embeds1.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
# remove non-alphabet characters to avoid splitting by tokenizer
# TODO make random alphabet string
token_string = "".join([c for c in token_string if c.isalpha()])
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
f"tokenizer has same word to token string (filename): {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}"
)
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)

View File

@@ -39,18 +39,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers):
# tokenizer 1 is seems to be ok
# count words for token string: regular expression from open_clip
pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
words = regex.findall(pat, token_string)
word_count = len(words)
assert word_count == 1, (
f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
+ f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]

View File

@@ -8,6 +8,7 @@ from tqdm import tqdm
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util
import library.train_util as train_util
@@ -92,7 +93,7 @@ class TextualInversionTrainer:
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers):
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
pass
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
@@ -200,19 +201,13 @@ class TextualInversionTrainer:
init_token_ids_list = [None] * len(tokenizers)
# tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
# token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
# add new word to tokenizer, count is num_vectors_per_token
# token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
# 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした
# if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
# originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [
f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
]
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):