diff --git a/library/train_util.py b/library/train_util.py index 8a73445e..c5e903bb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3761,8 +3761,9 @@ def pool_workaround( # get hidden states for EOS token pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] - # apply projection - pooled_output = text_encoder.text_projection(pooled_output) + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) return pooled_output