From b4a3824ce4b98369129339c2502b5200e955b18a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Jul 2023 20:49:26 +0900 Subject: [PATCH] change tokenizer from open clip to transformers --- library/sdxl_train_util.py | 106 ++++++-------------------------- sdxl_gen_img.py | 10 +-- sdxl_train_textual_inversion.py | 12 ---- train_textual_inversion.py | 15 ++--- 4 files changed, 27 insertions(+), 116 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index a1480777..4fc14bf2 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,8 @@ from diffusers import StableDiffusionXLPipeline from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline -TOKENIZER_PATH = "openai/clip-vit-large-patch14" +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" DEFAULT_NOISE_OFFSET = 0.0357 @@ -108,101 +109,32 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info -class WrapperTokenizer: - # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする - # make open clip tokenizer compatible with HuggingFace tokenizer - def __init__(self): - open_clip_tokenizer = open_clip.tokenizer._tokenizer - self.model_max_length = 77 - self.bos_token_id = open_clip_tokenizer.all_special_ids[0] - self.eos_token_id = open_clip_tokenizer.all_special_ids[1] - self.pad_token_id = 0 # 結果から推定している assumption from result - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.tokenize(*args, **kwds) - - def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None): - if padding == "max_length": - # for training - assert max_length is not None - assert truncation == True - assert return_tensors == "pt" - input_ids = open_clip.tokenize(text, context_length=max_length) - return SimpleNamespace(**{"input_ids": input_ids}) - - # for weighted prompt - assert isinstance(text, str), f"input must be str: {text}" - - input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list - - # find eos - eos_index = (input_ids == self.eos_token_id).nonzero().max() - input_ids = input_ids[: eos_index + 1] # include eos - return SimpleNamespace(**{"input_ids": input_ids}) - - # for Textual Inversion - # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI? 
- - def encode(self, text, add_special_tokens=False): - assert not add_special_tokens - input_ids = open_clip.tokenizer._tokenizer.encode(text) - return input_ids - - def add_tokens(self, new_tokens): - tokens_to_add = [] - for token in new_tokens: - token = token.lower() - if token + "" not in open_clip.tokenizer._tokenizer.encoder: - tokens_to_add.append(token) - - # open clipのtokenizerに直接追加する / add tokens to open clip tokenizer - for token in tokens_to_add: - open_clip.tokenizer._tokenizer.encoder[token + ""] = len(open_clip.tokenizer._tokenizer.encoder) - open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "" - open_clip.tokenizer._tokenizer.vocab_size += 1 - - # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする - # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる - # set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe - # this is very rough, so it will not work if the specification of open clip tokenizer changes - open_clip.tokenizer._tokenizer.cache[token] = token + "" - - return len(tokens_to_add) - - def convert_tokens_to_ids(self, tokens): - input_ids = [open_clip.tokenizer._tokenizer.encoder[token + ""] for token in tokens] - return input_ids - - def __len__(self): - return open_clip.tokenizer._tokenizer.vocab_size - - def load_tokenizers(args: argparse.Namespace): print("prepare tokenizers") - original_path = TOKENIZER_PATH - tokenizer1: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path) + original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] + tokeniers = [] + for original_path in original_paths: + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) - if tokenizer1 is None: - tokenizer1 = CLIPTokenizer.from_pretrained(original_path) + if tokenizer is None: + tokenizer = CLIPTokenizer.from_pretrained(original_path) - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer1.save_pretrained(local_tokenizer_path) + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + tokeniers.append(tokenizer) if hasattr(args, "max_token_length") and args.max_token_length is not None: print(f"update token length: {args.max_token_length}") - # tokenizer2 is from open_clip - # TODO caching - tokenizer2 = WrapperTokenizer() - - return [tokenizer1, tokenizer2] + return tokeniers def get_hidden_states( diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 1e20595c..cc9e8a28 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1605,18 +1605,14 @@ def main(args): num_vectors_per_token = embeds1.size()[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0] - # remove non-alphabet characters to avoid splitting by tokenizer - # TODO make random alphabet string - token_string = "".join([c for c in token_string if 
c.isalpha()]) - - token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] # add new word to tokenizer, count is num_vectors_per_token num_added_tokens1 = tokenizer1.add_tokens(token_strings) num_added_tokens2 = tokenizer2.add_tokens(token_strings) assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( - f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}" - + f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}" + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" ) token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 9df37092..2616a22c 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -39,18 +39,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine tokenizer = sdxl_train_util.load_tokenizers(args) return tokenizer - def assert_token_string(self, token_string, tokenizers): - # tokenizer 1 is seems to be ok - - # count words for token string: regular expression from open_clip - pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE) - words = regex.findall(pat, token_string) - word_count = len(words) - assert word_count == 1, ( - f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters" - + f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください" - ) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): input_ids1 = batch["input_ids"] input_ids2 = batch["input_ids2"] diff --git a/train_textual_inversion.py b/train_textual_inversion.py index cbfd48ce..7be8ba80 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,6 +8,7 @@ from tqdm import tqdm import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from transformers import CLIPTokenizer from library import model_util import library.train_util as train_util @@ -92,7 +93,7 @@ class TextualInversionTrainer: tokenizer = train_util.load_tokenizer(args) return tokenizer - def assert_token_string(self, token_string, tokenizers): + def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): pass def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): @@ -200,19 +201,13 @@ class TextualInversionTrainer: init_token_ids_list = [None] * len(tokenizers) # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token + # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される # add new word to tokenizer, count is num_vectors_per_token - - # token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される - # 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした - - # if token_string is hoge, "hoge", "hogea", "hogeb", ... are added - # originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used + # if token_string is hoge, "hoge", "hoge1", "hoge2", ... 
are added self.assert_token_string(args.token_string, tokenizers) - token_strings = [args.token_string] + [ - f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1) - ] + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] token_ids_list = [] token_embeds_list = [] for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
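
With this patch, both SDXL tokenizers are plain transformers CLIPTokenizer instances (CLIP-L from openai/clip-vit-large-patch14 for text encoder 1, CLIP-bigG from laion/CLIP-ViT-bigG-14-laion2B-39B-b160k for text encoder 2), optionally cached on disk via --tokenizer_cache_dir. The following is a minimal standalone sketch of that loading logic, assuming transformers is installed; the helper name is illustrative and not the exact function body in library/sdxl_train_util.py.

import os
from typing import List, Optional

from transformers import CLIPTokenizer

# Hugging Face tokenizer repos used by the patch:
# text encoder 1 (CLIP-L) and text encoder 2 (CLIP-bigG).
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"


def load_sdxl_tokenizers(tokenizer_cache_dir: Optional[str] = None) -> List[CLIPTokenizer]:
    # Returns [tokenizer1, tokenizer2]; both are ordinary CLIPTokenizer objects,
    # so the open_clip WrapperTokenizer is no longer needed.
    tokenizers = []
    for original_path in [TOKENIZER1_PATH, TOKENIZER2_PATH]:
        tokenizer = None
        if tokenizer_cache_dir:
            local_path = os.path.join(tokenizer_cache_dir, original_path.replace("/", "_"))
            if os.path.exists(local_path):
                # Reuse a previously cached copy instead of downloading.
                tokenizer = CLIPTokenizer.from_pretrained(local_path)

        if tokenizer is None:
            tokenizer = CLIPTokenizer.from_pretrained(original_path)

        if tokenizer_cache_dir and not os.path.exists(local_path):
            # Cache the downloaded tokenizer files for offline reuse.
            tokenizer.save_pretrained(local_path)

        tokenizers.append(tokenizer)
    return tokenizers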
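
The switch back to numeric suffixes in the Textual Inversion scripts also relies on the transformers tokenizer: tokens registered with add_tokens() are matched as whole units, so names like "hoge1" are no longer broken apart the way open_clip's BPE split words containing digits. A rough illustration, with a hypothetical embedding name and vector count:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

token_string = "hoge"        # example embedding name (normally taken from the embedding file name)
num_vectors_per_token = 3    # example value
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# -> ["hoge", "hoge1", "hoge2"]

num_added = tokenizer.add_tokens(token_strings)
assert num_added == num_vectors_per_token, f"some of {token_strings} already exist in the tokenizer"

token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(token_ids)  # new ids appended after the original vocabulary

# Each registered token encodes to a single id, so "hoge1" survives as one token.
encoded = tokenizer("a photo of hoge1", padding="max_length", truncation=True, max_length=77, return_tensors="pt")
print(encoded.input_ids.shape)  # torch.Size([1, 77])

In the training scripts the text encoders' token embedding tables are then resized so these new ids map to trainable vectors; that step lies outside the hunks changed by this patch.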