From 974674242e3f43ff5df7b89a450dd3bc877022b1 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Tue, 10 Jan 2023 22:20:07 +0900 Subject: [PATCH 1/2] add negative_scale --- gen_img_diffusers.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 208b1b70..ff3919c3 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -694,8 +694,21 @@ class PipelineLike(): **kwargs, ) + if args.negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=[""]*batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) + if do_classifier_free_guidance: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + if args.negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: @@ -828,16 +841,20 @@ class PipelineLike(): for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents + repeats = 3 if args.negative_scale is not None else 2 + latent_model_input = latents.repeat((repeats, 1, 1, 1)) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if args.negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - args.negative_scale * (noise_pred_negative - noise_pred_uncond) # perform clip guidance if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: @@ -2499,6 +2516,8 @@ if __name__ == '__main__': help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") parser.add_argument("--highres_fix_save_1st", action='store_true', help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") + parser.add_argument("--negative_scale", type=float, default=None, + help="correctly cfg") args = parser.parse_args() main(args) From 58d24ba254d0499039b8164630826939e5d33f2a Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Tue, 10 Jan 2023 22:24:20 +0900 Subject: [PATCH 2/2] Update gen_img_diffusers.py --- gen_img_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index ff3919c3..370f25b3 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2517,7 +2517,7 @@ if __name__ == '__main__': parser.add_argument("--highres_fix_save_1st", action='store_true', help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") parser.add_argument("--negative_scale", type=float, default=None, - help="correctly cfg") + help="scaling negative prompt") args = parser.parse_args() main(args)