diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index f4b2c173..c67a7043 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -98,13 +98,16 @@ class WrapperTokenizer:
             return SimpleNamespace(**{"input_ids": input_ids})
 
         # for weighted prompt
-        input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
+        assert isinstance(text, str), f"input must be str: {text}"
+
+        input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0]  # tokenizer returns list
 
         # find eos
-        eos_index = (input_ids == self.eos_token_id).nonzero()[0].max()  # max index of each batch
-        input_ids = input_ids[:, : eos_index + 1]  # include eos
+        eos_index = (input_ids == self.eos_token_id).nonzero().max()
+        input_ids = input_ids[: eos_index + 1]  # include eos
         return SimpleNamespace(**{"input_ids": input_ids})
 
+
 def load_tokenizers(args: argparse.Namespace):
     print("prepare tokenizers")
     original_path = TOKENIZER_PATH
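
For context, a minimal standalone sketch (not part of the patch) of the shape issue this fix addresses: open_clip.tokenize always returns a batched LongTensor of shape (n_prompts, context_length), even for a single string, which is why the patch indexes row [0] and then works with a 1-D tensor. The EOS_TOKEN_ID constant below is a demo stand-in for self.eos_token_id; 49407 is the usual CLIP BPE end-of-text id, but in the wrapper the value comes from the tokenizer itself.

    # sketch only; assumes open_clip-torch is installed
    import open_clip

    EOS_TOKEN_ID = 49407  # stand-in for self.eos_token_id in the wrapper

    text = "a photo of a cat"
    batched = open_clip.tokenize(text, context_length=77)  # shape: (1, 77)
    input_ids = batched[0]                                 # shape: (77,)

    # On a 1-D tensor, nonzero() yields one index per match, so .max() is the
    # position of the eos token; slicing keeps bos ... eos and drops padding.
    eos_index = (input_ids == EOS_TOKEN_ID).nonzero().max()
    trimmed = input_ids[: eos_index + 1]
    print(trimmed.shape)  # e.g. torch.Size([7]) for this prompt

On the removed 2-D code path, (input_ids == self.eos_token_id).nonzero() produces (row, col) index pairs, so nonzero()[0].max() mixed the two axes rather than finding the eos position per batch as its comment claimed; restricting the wrapper to a single str input and trimming in 1-D sidesteps that.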