From e5f9772a354055dad62758edd2b77665f4aefa5d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 5 Aug 2023 21:22:50 +0900 Subject: [PATCH] fix training textencoder in sdxl not working --- library/train_util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8a73445e..c5e903bb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3761,8 +3761,9 @@ def pool_workaround( # get hidden states for EOS token pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] - # apply projection - pooled_output = text_encoder.text_projection(pooled_output) + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) return pooled_output