fix uncond/cond order

This commit is contained in:
Kohya S
2023-07-09 21:14:12 +09:00
parent 77ec70d145
commit c2ceb6de5f

View File

@@ -974,11 +974,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if do_classifier_free_guidance: if do_classifier_free_guidance:
text_embeddings = torch.cat(text_embeddings_list, dim=2) text_embeddings = torch.cat(text_embeddings_list, dim=2)
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
text_embedding = torch.cat([text_embeddings, uncond_embeddings]).to(dtype) text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
cond_vector = torch.cat([text_pool, embs], dim=1) cond_vector = torch.cat([text_pool, embs], dim=1)
uncond_vector = torch.cat([uncond_pool, embs], dim=1) uncond_vector = torch.cat([uncond_pool, embs], dim=1)
vector_embedding = torch.cat([cond_vector, uncond_vector]).to(dtype) vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
else: else:
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)