From 5851b2b7730db0efc497a9e4339924430b2c44c9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Jan 2023 18:43:54 +0900 Subject: [PATCH] Negative scale from prompt option --- gen_img_diffusers.py | 68 +++++++++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index db81d48a..91cdc51e 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -557,6 +557,7 @@ class PipelineLike(): width: int = 512, num_inference_steps: int = 50, guidance_scale: float = 7.5, + negative_scale: float = None, strength: float = 0.8, # num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, @@ -675,6 +676,11 @@ class PipelineLike(): # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + # get unconditional embeddings for classifier free guidance if negative_prompt is None: negative_prompt = [""] * batch_size @@ -696,21 +702,21 @@ class PipelineLike(): **kwargs, ) - if args.negative_scale is not None: + if negative_scale is not None: _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=[""]*batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) + pipe=self, + prompt=[""]*batch_size, + uncond_prompt=[""]*batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) if do_classifier_free_guidance: - if args.negative_scale is None: + if negative_scale is None: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: @@ -841,26 +847,28 @@ class PipelineLike(): if accepts_eta: extra_step_kwargs["eta"] = eta + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance - repeats = 3 if args.negative_scale is not None else 2 - latent_model_input = latents.repeat((repeats, 1, 1, 1)) if do_classifier_free_guidance else latents + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: - if args.negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - args.negative_scale * (noise_pred_negative - noise_pred_uncond) + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond + noise_pred = noise_pred_uncond + guidance_scale * \ + (noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond) # perform clip guidance if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: - text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings) + text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[ + 1] if do_classifier_free_guidance else text_embeddings) if self.clip_guidance_scale > 0: noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, @@ -2163,12 +2171,12 @@ def main(args): # 1st stageのバッチを作成して呼び出す print("process 1st stage1") batch_1st = [] - for params1, (width, height, steps, scale, strength) in batch: + for params1, (width, height, steps, scale, negative_scale, strength) in batch: width_1st = int(width * args.highres_fix_scale + .5) height_1st = int(height * args.highres_fix_scale + .5) width_1st = width_1st - width_1st % 32 height_1st = height_1st - height_1st % 32 - batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength))) + batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength))) images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する @@ -2180,7 +2188,8 @@ def main(args): batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2)) batch = batch_2nd - (step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0] + (step_first, _, _, _, init_image, mask_image, _, guide_image), (width, + height, steps, scale, negative_scale, strength) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) prompts = [] @@ -2256,7 +2265,7 @@ def main(args): guide_images = guide_images[0] # generate - images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code, + images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] if highres_1st and not args.highres_fix_save_1st: return images @@ -2273,6 +2282,8 @@ def main(args): metadata.add_text("scale", str(scale)) if negative_prompt is not None: metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) @@ -2325,6 +2336,7 @@ def main(args): width = args.W height = args.H scale = args.scale + negative_scale = args.negative_scale steps = args.steps seeds = None strength = 0.8 if args.strength is None else args.strength @@ -2367,6 +2379,15 @@ def main(args): print(f"scale: {scale}") continue + m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == 'none': + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) @@ -2429,8 +2450,9 @@ def main(args): print("Use previous image as guide image.") guide_image = prev_image + # TODO named tupleか何かにする b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), - (width, height, steps, scale, strength)) + (width, height, steps, scale, negative_scale, strength)) if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? process_batch(batch_data, highres_fix) batch_data.clear() @@ -2527,7 +2549,7 @@ if __name__ == '__main__': parser.add_argument("--highres_fix_save_1st", action='store_true', help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") parser.add_argument("--negative_scale", type=float, default=None, - help="scaling negative prompt") + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する") args = parser.parse_args() main(args)