diff --git a/library/train_util.py b/library/train_util.py index 34e477ed..031ce5a8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3886,8 +3886,16 @@ def pool_workaround( # input_ids: b*n,77 # find index for EOS token - eos_token_index = torch.where(input_ids == eos_token_id)[1] - eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # The following code does not work if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the first index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine # get hidden states for EOS token pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]