fix: max multiple-embeddings handling doesn't work (closes #656)

This commit is contained in:
Kohya S
2023-07-23 15:18:27 +09:00
parent c1d5c24bc7
commit 7ae0cde754

View File

@@ -958,7 +958,7 @@ def get_unweighted_text_embeddings(
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-2]
if pool is None:
pool = enc_out["text_embeds"] # use 1st chunk
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
if no_boseos_middle:
if i == 0: