fix max_token_length not working for SDXL

This commit is contained in:
Kohya S
2023-06-29 13:02:19 +09:00
parent 8521ab7990
commit d395bc0647
5 changed files with 39 additions and 33 deletions

View File

@@ -3358,6 +3358,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
if input_ids.size()[-1] != tokenizer.model_max_length:
return text_encoder(input_ids)[0]
# input_ids: b,n,77
b_size = input_ids.size()[0]
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77