change tokenizer from open clip to transformers
(commit in a mirror of https://github.com/kohya-ss/sd-scripts.git)
@@ -8,6 +8,7 @@ from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
+from transformers import CLIPTokenizer
 from library import model_util

 import library.train_util as train_util
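
For reference, a minimal sketch of the newly imported transformers CLIPTokenizer in use. In the script itself the tokenizer is obtained via train_util.load_tokenizer(args); the checkpoint name below is only an assumption for this illustration.

# Illustration only: what the newly imported tokenizer looks like in use.
# In sd-scripts the tokenizer actually comes from train_util.load_tokenizer(args);
# the checkpoint name below is an assumption for this sketch.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenizer(
    ["a photo of hoge"],
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids
print(ids.shape)  # torch.Size([1, 77]) for the standard CLIP text encoder
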
@@ -92,7 +93,7 @@ class TextualInversionTrainer:
         tokenizer = train_util.load_tokenizer(args)
         return tokenizer

-    def assert_token_string(self, token_string, tokenizers):
+    def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
         pass

     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
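
Note that assert_token_string only gains a type annotation here; its body is still pass. Purely as a hypothetical illustration of the kind of validation such a hook could perform (this is not the repository's implementation), a standalone sketch assuming a transformers CLIPTokenizer:

# Hypothetical sketch only -- in this hunk the method body is still `pass`, so this is
# NOT the repository's implementation. One plausible check for such a hook: warn when
# the raw token string already splits into several existing tokens.
from transformers import CLIPTokenizer

def check_token_string(token_string: str, tokenizer: CLIPTokenizer) -> None:
    ids = tokenizer(token_string, add_special_tokens=False).input_ids
    if len(ids) > 1:
        print(f"'{token_string}' currently tokenizes into {len(ids)} tokens")

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint
check_token_string("hoge", tokenizer)
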
@@ -200,19 +201,13 @@ class TextualInversionTrainer:
         init_token_ids_list = [None] * len(tokenizers)

         # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
+        # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
         # add new word to tokenizer, count is num_vectors_per_token
-
-        # token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
-        # 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした
-
-        # if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
-        # originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
+        # if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added

         self.assert_token_string(args.token_string, tokenizers)

-        token_strings = [args.token_string] + [
-            f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
-        ]
+        token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
         token_ids_list = []
         token_embeds_list = []
         for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
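
This hunk switches the extra token names back from letter suffixes ("hogea", "hogeb", ...) to numeric suffixes ("hoge1", "hoge2", ...), which the transformers CLIPTokenizer can keep as single added tokens, whereas open clip's tokenizer splits words containing digits. A rough sketch of the surrounding textual-inversion flow (not the exact sd-scripts code; the checkpoint name and the example values of token_string and num_vectors_per_token are assumptions): the generated strings are registered with the tokenizer and the text encoder's embedding table is resized so each new token gets its own trainable vector.

# Minimal sketch of the usual flow after token_strings is built (illustrative,
# not the exact sd-scripts code): register the new words, then grow the
# text encoder's embedding table so each one gets its own trainable vector.
from transformers import CLIPTokenizer, CLIPTextModel

token_string = "hoge"           # example value; comes from args.token_string in the script
num_vectors_per_token = 3       # example value; comes from args.num_vectors_per_token
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# -> ["hoge", "hoge1", "hoge2"]

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")    # assumed checkpoint
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

num_added = tokenizer.add_tokens(token_strings)
assert num_added == len(token_strings), "some token strings already exist in the vocabulary"

text_encoder.resize_token_embeddings(len(tokenizer))
token_ids = tokenizer.convert_tokens_to_ids(token_strings)  # ids of the new embeddings to train
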