From 39bb319d4cac05d7da054ee726f86061e629574d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 29 Nov 2023 12:42:12 +0900 Subject: [PATCH] fix to work with cfg scale=1 --- sdxl_gen_img.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 78b90f8c..ab539984 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -504,7 +504,8 @@ class PipelineLike: uncond_embeddings = tes_uncond_embs[0] for i in range(1, len(tes_text_embs)): text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 if do_classifier_free_guidance: if negative_scale is None: @@ -567,9 +568,11 @@ class PipelineLike: text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device)