From fe7ede5af33e9d4fecc22ee454f06e28c4bcff74 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 9 Jul 2023 13:33:16 +0900 Subject: [PATCH] fix wrapper tokenizer not working for weighted prompt --- library/sdxl_train_util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f4b2c173..c67a7043 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -98,13 +98,16 @@ class WrapperTokenizer: return SimpleNamespace(**{"input_ids": input_ids}) # for weighted prompt - input_ids = open_clip.tokenize(text, context_length=self.model_max_length) + assert isinstance(text, str), f"input must be str: {text}" + + input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list # find eos - eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch - input_ids = input_ids[:, : eos_index + 1] # include eos + eos_index = (input_ids == self.eos_token_id).nonzero().max() + input_ids = input_ids[: eos_index + 1] # include eos return SimpleNamespace(**{"input_ids": input_ids}) + def load_tokenizers(args: argparse.Namespace): print("prepare tokenizers") original_path = TOKENIZER_PATH