Negative scale from prompt option

This commit is contained in:
Kohya S
2023-01-14 18:43:54 +09:00
parent e4695e9359
commit 5851b2b773

View File

@@ -557,6 +557,7 @@ class PipelineLike():
width: int = 512, width: int = 512,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
negative_scale: float = None,
strength: float = 0.8, strength: float = 0.8,
# num_images_per_prompt: Optional[int] = 1, # num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
@@ -675,6 +676,11 @@ class PipelineLike():
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scale <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if negative_prompt is None: if negative_prompt is None:
negative_prompt = [""] * batch_size negative_prompt = [""] * batch_size
@@ -696,10 +702,10 @@ class PipelineLike():
**kwargs, **kwargs,
) )
if args.negative_scale is not None: if negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings( _, real_uncond_embeddings, _ = get_weighted_text_embeddings(
pipe=self, pipe=self,
prompt=prompt, prompt=[""]*batch_size,
uncond_prompt=[""]*batch_size, uncond_prompt=[""]*batch_size,
max_embeddings_multiples=max_embeddings_multiples, max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
@@ -707,7 +713,7 @@ class PipelineLike():
) )
if do_classifier_free_guidance: if do_classifier_free_guidance:
if args.negative_scale is None: if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else: else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
@@ -841,26 +847,28 @@ class PipelineLike():
if accepts_eta: if accepts_eta:
extra_step_kwargs["eta"] = eta extra_step_kwargs["eta"] = eta
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
repeats = 3 if args.negative_scale is not None else 2 latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = latents.repeat((repeats, 1, 1, 1)) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
if args.negative_scale is None: if negative_scale is None:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else: else:
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(3) noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - args.negative_scale * (noise_pred_negative - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * \
(noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond)
# perform clip guidance # perform clip guidance
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings) text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[
1] if do_classifier_free_guidance else text_embeddings)
if self.clip_guidance_scale > 0: if self.clip_guidance_scale > 0:
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
@@ -2163,12 +2171,12 @@ def main(args):
# 1st stageのバッチを作成して呼び出す # 1st stageのバッチを作成して呼び出す
print("process 1st stage1") print("process 1st stage1")
batch_1st = [] batch_1st = []
for params1, (width, height, steps, scale, strength) in batch: for params1, (width, height, steps, scale, negative_scale, strength) in batch:
width_1st = int(width * args.highres_fix_scale + .5) width_1st = int(width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5) height_1st = int(height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32 width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32 height_1st = height_1st - height_1st % 32
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength))) batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
images_1st = process_batch(batch_1st, True, True) images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する # 2nd stageのバッチを作成して以下処理する
@@ -2180,7 +2188,8 @@ def main(args):
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2)) batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
batch = batch_2nd batch = batch_2nd
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0] (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
height, steps, scale, negative_scale, strength) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = [] prompts = []
@@ -2256,7 +2265,7 @@ def main(args):
guide_images = guide_images[0] guide_images = guide_images[0]
# generate # generate
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code, images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st: if highres_1st and not args.highres_fix_save_1st:
return images return images
@@ -2273,6 +2282,8 @@ def main(args):
metadata.add_text("scale", str(scale)) metadata.add_text("scale", str(scale))
if negative_prompt is not None: if negative_prompt is not None:
metadata.add_text("negative-prompt", negative_prompt) metadata.add_text("negative-prompt", negative_prompt)
if negative_scale is not None:
metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None: if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt) metadata.add_text("clip-prompt", clip_prompt)
@@ -2325,6 +2336,7 @@ def main(args):
width = args.W width = args.W
height = args.H height = args.H
scale = args.scale scale = args.scale
negative_scale = args.negative_scale
steps = args.steps steps = args.steps
seeds = None seeds = None
strength = 0.8 if args.strength is None else args.strength strength = 0.8 if args.strength is None else args.strength
@@ -2367,6 +2379,15 @@ def main(args):
print(f"scale: {scale}") print(f"scale: {scale}")
continue continue
m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE)
if m: # negative scale
if m.group(1).lower() == 'none':
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
continue
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE) m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
if m: # strength if m: # strength
strength = float(m.group(1)) strength = float(m.group(1))
@@ -2429,8 +2450,9 @@ def main(args):
print("Use previous image as guide image.") print("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
# TODO named tupleか何かにする
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
(width, height, steps, scale, strength)) (width, height, steps, scale, negative_scale, strength))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
@@ -2527,7 +2549,7 @@ if __name__ == '__main__':
parser.add_argument("--highres_fix_save_1st", action='store_true', parser.add_argument("--highres_fix_save_1st", action='store_true',
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
parser.add_argument("--negative_scale", type=float, default=None, parser.add_argument("--negative_scale", type=float, default=None,
help="scaling negative prompt") help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)