fix to work when input_ids has multiple EOS tokens

This commit is contained in:
Kohya S
2023-08-10 20:13:59 +09:00
parent 4e339bb101
commit daad50e384

View File

@@ -3886,8 +3886,16 @@ def pool_workaround(
# input_ids: b*n,77 # input_ids: b*n,77
# find index for EOS token # find index for EOS token
eos_token_index = torch.where(input_ids == eos_token_id)[1]
eos_token_index = eos_token_index.to(device=last_hidden_state.device) # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# Create a mask where the EOS tokens are
eos_token_mask = (input_ids == eos_token_id).int()
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
# get hidden states for EOS token # get hidden states for EOS token
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]