Merge pull request #720 from kohya-ss/dev

Fix text encoder training not working in SDXL
This commit is contained in:
Kohya S
2023-08-05 21:24:24 +09:00
committed by GitHub

View File

@@ -3761,8 +3761,9 @@ def pool_workaround(
 # get hidden states for EOS token
 pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
-# apply projection
-pooled_output = text_encoder.text_projection(pooled_output)
+# apply projection: projection may be of different dtype than last_hidden_state
+pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
+pooled_output = pooled_output.to(last_hidden_state.dtype)
 return pooled_output