mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix wrapper tokenizer not work for weighted prompt
This commit is contained in:
@@ -98,13 +98,16 @@ class WrapperTokenizer:
|
|||||||
return SimpleNamespace(**{"input_ids": input_ids})
|
return SimpleNamespace(**{"input_ids": input_ids})
|
||||||
|
|
||||||
# for weighted prompt
|
# for weighted prompt
|
||||||
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
|
assert isinstance(text, str), f"input must be str: {text}"
|
||||||
|
|
||||||
|
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list
|
||||||
|
|
||||||
# find eos
|
# find eos
|
||||||
eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch
|
eos_index = (input_ids == self.eos_token_id).nonzero().max()
|
||||||
input_ids = input_ids[:, : eos_index + 1] # include eos
|
input_ids = input_ids[: eos_index + 1] # include eos
|
||||||
return SimpleNamespace(**{"input_ids": input_ids})
|
return SimpleNamespace(**{"input_ids": input_ids})
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizers(args: argparse.Namespace):
|
def load_tokenizers(args: argparse.Namespace):
|
||||||
print("prepare tokenizers")
|
print("prepare tokenizers")
|
||||||
original_path = TOKENIZER_PATH
|
original_path = TOKENIZER_PATH
|
||||||
|
|||||||
Reference in New Issue
Block a user