mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix wrapper tokenizer not working for weighted prompts
This commit is contained in:
@@ -98,13 +98,16 @@ class WrapperTokenizer:
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
# for weighted prompt
|
||||
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
|
||||
assert isinstance(text, str), f"input must be str: {text}"
|
||||
|
||||
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list
|
||||
|
||||
# find eos
|
||||
eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch
|
||||
input_ids = input_ids[:, : eos_index + 1] # include eos
|
||||
eos_index = (input_ids == self.eos_token_id).nonzero().max()
|
||||
input_ids = input_ids[: eos_index + 1] # include eos
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
original_path = TOKENIZER_PATH
|
||||
|
||||
Reference in New Issue
Block a user