mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to work when input_ids has multiple EOS tokens
This commit is contained in:
@@ -3886,8 +3886,16 @@ def pool_workaround(
|
|||||||
|
|
||||||
# input_ids: b*n,77
|
# input_ids: b*n,77
|
||||||
# find index for EOS token
|
# find index for EOS token
|
||||||
eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
|
||||||
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
|
||||||
|
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
||||||
|
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||||
|
|
||||||
|
# Create a mask where the EOS tokens are
|
||||||
|
eos_token_mask = (input_ids == eos_token_id).int()
|
||||||
|
|
||||||
|
# Use argmax to find the last index of the EOS token for each element in the batch
|
||||||
|
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
|
||||||
|
|
||||||
# get hidden states for EOS token
|
# get hidden states for EOS token
|
||||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
||||||
|
|||||||
Reference in New Issue
Block a user