Fix wrapper tokenizer not working for weighted prompts

This commit is contained in:
ykume
2023-07-09 13:33:16 +09:00
parent c1d62383c6
commit fe7ede5af3

View File

@@ -98,13 +98,16 @@ class WrapperTokenizer:
return SimpleNamespace(**{"input_ids": input_ids})
# for weighted prompt
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
assert isinstance(text, str), f"input must be str: {text}"
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list
# find eos
eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch
input_ids = input_ids[:, : eos_index + 1] # include eos
eos_index = (input_ids == self.eos_token_id).nonzero().max()
input_ids = input_ids[: eos_index + 1] # include eos
return SimpleNamespace(**{"input_ids": input_ids})
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
original_path = TOKENIZER_PATH