change tokenizer from open clip to transformers

This commit is contained in:
Kohya S
2023-07-13 20:49:26 +09:00
parent 3bb80ebf20
commit b4a3824ce4
4 changed files with 27 additions and 116 deletions

View File

@@ -8,6 +8,7 @@ from tqdm import tqdm
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util
import library.train_util as train_util
@@ -92,7 +93,7 @@ class TextualInversionTrainer:
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers):
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
pass
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
@@ -200,19 +201,13 @@ class TextualInversionTrainer:
init_token_ids_list = [None] * len(tokenizers)
# tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
# token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
# add new word to tokenizer, count is num_vectors_per_token
# token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
# 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした
# if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
# originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [
f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
]
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):