diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index d2f59a33..6578b9a8 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -511,7 +511,7 @@ class PipelineLike: emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype) c_vector = torch.cat([text_pool, c_vector], dim=1)