Merge pull request #744 from kohya-ss/dev

fix sample gen failed in sdxl training
This commit is contained in:
Kohya S
2023-08-11 17:00:52 +09:00
committed by GitHub

View File

@@ -3896,6 +3896,7 @@ def pool_workaround(
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# get hidden states for EOS token
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]