mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix training textencoder in sdxl not working
This commit is contained in:
@@ -3761,8 +3761,9 @@ def pool_workaround(
|
|||||||
# get hidden states for EOS token
|
# get hidden states for EOS token
|
||||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
||||||
|
|
||||||
# apply projection
|
# apply projection: projection may be of different dtype than last_hidden_state
|
||||||
pooled_output = text_encoder.text_projection(pooled_output)
|
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
|
||||||
|
pooled_output = pooled_output.to(last_hidden_state.dtype)
|
||||||
|
|
||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user