From daad50e384e41703ac044d796658efc699340b05 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 10 Aug 2023 20:13:59 +0900 Subject: [PATCH] fix to work when input_ids has multiple EOS tokens --- library/train_util.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 34e477ed..031ce5a8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3886,8 +3886,16 @@ def pool_workaround( # input_ids: b*n,77 # find index for EOS token - eos_token_index = torch.where(input_ids == eos_token_id)[1] - eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # The following code does not work if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the first index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine # get hidden states for EOS token pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]