add negative_scale

This commit is contained in:
laksjdjf
2023-01-10 22:20:07 +09:00
committed by GitHub
parent f981dfd38a
commit 974674242e

View File

@@ -694,8 +694,21 @@ class PipelineLike():
**kwargs,
)
if args.negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=[""]*batch_size,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
if do_classifier_free_guidance:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if args.negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
# CLIP guidanceで使用するembeddingsを取得する
if self.clip_guidance_scale > 0:
@@ -828,16 +841,20 @@ class PipelineLike():
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents
repeats = 3 if args.negative_scale is not None else 2
latent_model_input = latents.repeat((repeats, 1, 1, 1)) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if args.negative_scale is None:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(3)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - args.negative_scale * (noise_pred_negative - noise_pred_uncond)
# perform clip guidance
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
@@ -2499,6 +2516,8 @@ if __name__ == '__main__':
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
parser.add_argument("--highres_fix_save_1st", action='store_true',
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
parser.add_argument("--negative_scale", type=float, default=None,
help="correctly cfg")
args = parser.parse_args()
main(args)