From e73d103eca275224024b35dc0967c75ef66feff6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Aug 2023 16:58:52 +0900 Subject: [PATCH] fix sample gen failed in sdxl training --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 031ce5a8..e88a3dcf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3896,6 +3896,7 @@ def pool_workaround( # Use argmax to find the last index of the EOS token for each element in the batch eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) # get hidden states for EOS token pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]