diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index f5133407..db81d48a 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -696,8 +696,21 @@ class PipelineLike(): **kwargs, ) + if args.negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=[""]*batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) + if do_classifier_free_guidance: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + if args.negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: @@ -830,16 +843,20 @@ class PipelineLike(): for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents + repeats = 3 if args.negative_scale is not None else 2 + latent_model_input = latents.repeat((repeats, 1, 1, 1)) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if args.negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - args.negative_scale * (noise_pred_negative - noise_pred_uncond) # perform clip guidance if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: @@ -2509,6 +2526,8 @@ if __name__ == '__main__': help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") parser.add_argument("--highres_fix_save_1st", action='store_true', help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") + parser.add_argument("--negative_scale", type=float, default=None, + help="scaling negative prompt") args = parser.parse_args() main(args)