mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Negative scale from prompt option
This commit is contained in:
@@ -557,6 +557,7 @@ class PipelineLike():
|
|||||||
width: int = 512,
|
width: int = 512,
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
guidance_scale: float = 7.5,
|
guidance_scale: float = 7.5,
|
||||||
|
negative_scale: float = None,
|
||||||
strength: float = 0.8,
|
strength: float = 0.8,
|
||||||
# num_images_per_prompt: Optional[int] = 1,
|
# num_images_per_prompt: Optional[int] = 1,
|
||||||
eta: float = 0.0,
|
eta: float = 0.0,
|
||||||
@@ -675,6 +676,11 @@ class PipelineLike():
|
|||||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
# corresponds to doing no classifier free guidance.
|
# corresponds to doing no classifier free guidance.
|
||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
if not do_classifier_free_guidance and negative_scale is not None:
|
||||||
|
print(f"negative_scale is ignored if guidance scalle <= 1.0")
|
||||||
|
negative_scale = None
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
# get unconditional embeddings for classifier free guidance
|
||||||
if negative_prompt is None:
|
if negative_prompt is None:
|
||||||
negative_prompt = [""] * batch_size
|
negative_prompt = [""] * batch_size
|
||||||
@@ -696,21 +702,21 @@ class PipelineLike():
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.negative_scale is not None:
|
if negative_scale is not None:
|
||||||
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
|
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
|
||||||
pipe=self,
|
pipe=self,
|
||||||
prompt=prompt,
|
prompt=[""]*batch_size,
|
||||||
uncond_prompt=[""]*batch_size,
|
uncond_prompt=[""]*batch_size,
|
||||||
max_embeddings_multiples=max_embeddings_multiples,
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
clip_skip=self.clip_skip,
|
clip_skip=self.clip_skip,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
if args.negative_scale is None:
|
if negative_scale is None:
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||||
else:
|
else:
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
||||||
|
|
||||||
# CLIP guidanceで使用するembeddingsを取得する
|
# CLIP guidanceで使用するembeddingsを取得する
|
||||||
if self.clip_guidance_scale > 0:
|
if self.clip_guidance_scale > 0:
|
||||||
@@ -841,26 +847,28 @@ class PipelineLike():
|
|||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_step_kwargs["eta"] = eta
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
repeats = 3 if args.negative_scale is not None else 2
|
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||||
latent_model_input = latents.repeat((repeats, 1, 1, 1)) if do_classifier_free_guidance else latents
|
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
if args.negative_scale is None:
|
if negative_scale is None:
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
else:
|
else:
|
||||||
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(3)
|
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - args.negative_scale * (noise_pred_negative - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * \
|
||||||
|
(noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond)
|
||||||
|
|
||||||
# perform clip guidance
|
# perform clip guidance
|
||||||
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
|
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
|
||||||
text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings)
|
text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[
|
||||||
|
1] if do_classifier_free_guidance else text_embeddings)
|
||||||
|
|
||||||
if self.clip_guidance_scale > 0:
|
if self.clip_guidance_scale > 0:
|
||||||
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
|
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
|
||||||
@@ -2163,12 +2171,12 @@ def main(args):
|
|||||||
# 1st stageのバッチを作成して呼び出す
|
# 1st stageのバッチを作成して呼び出す
|
||||||
print("process 1st stage1")
|
print("process 1st stage1")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for params1, (width, height, steps, scale, strength) in batch:
|
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
|
||||||
width_1st = int(width * args.highres_fix_scale + .5)
|
width_1st = int(width * args.highres_fix_scale + .5)
|
||||||
height_1st = int(height * args.highres_fix_scale + .5)
|
height_1st = int(height * args.highres_fix_scale + .5)
|
||||||
width_1st = width_1st - width_1st % 32
|
width_1st = width_1st - width_1st % 32
|
||||||
height_1st = height_1st - height_1st % 32
|
height_1st = height_1st - height_1st % 32
|
||||||
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength)))
|
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
|
||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
@@ -2180,7 +2188,8 @@ def main(args):
|
|||||||
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
||||||
batch = batch_2nd
|
batch = batch_2nd
|
||||||
|
|
||||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0]
|
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
|
||||||
|
height, steps, scale, negative_scale, strength) = batch[0]
|
||||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
@@ -2256,7 +2265,7 @@ def main(args):
|
|||||||
guide_images = guide_images[0]
|
guide_images = guide_images[0]
|
||||||
|
|
||||||
# generate
|
# generate
|
||||||
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code,
|
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
||||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||||
if highres_1st and not args.highres_fix_save_1st:
|
if highres_1st and not args.highres_fix_save_1st:
|
||||||
return images
|
return images
|
||||||
@@ -2273,6 +2282,8 @@ def main(args):
|
|||||||
metadata.add_text("scale", str(scale))
|
metadata.add_text("scale", str(scale))
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
metadata.add_text("negative-prompt", negative_prompt)
|
metadata.add_text("negative-prompt", negative_prompt)
|
||||||
|
if negative_scale is not None:
|
||||||
|
metadata.add_text("negative-scale", str(negative_scale))
|
||||||
if clip_prompt is not None:
|
if clip_prompt is not None:
|
||||||
metadata.add_text("clip-prompt", clip_prompt)
|
metadata.add_text("clip-prompt", clip_prompt)
|
||||||
|
|
||||||
@@ -2325,6 +2336,7 @@ def main(args):
|
|||||||
width = args.W
|
width = args.W
|
||||||
height = args.H
|
height = args.H
|
||||||
scale = args.scale
|
scale = args.scale
|
||||||
|
negative_scale = args.negative_scale
|
||||||
steps = args.steps
|
steps = args.steps
|
||||||
seeds = None
|
seeds = None
|
||||||
strength = 0.8 if args.strength is None else args.strength
|
strength = 0.8 if args.strength is None else args.strength
|
||||||
@@ -2367,6 +2379,15 @@ def main(args):
|
|||||||
print(f"scale: {scale}")
|
print(f"scale: {scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE)
|
||||||
|
if m: # negative scale
|
||||||
|
if m.group(1).lower() == 'none':
|
||||||
|
negative_scale = None
|
||||||
|
else:
|
||||||
|
negative_scale = float(m.group(1))
|
||||||
|
print(f"negative scale: {negative_scale}")
|
||||||
|
continue
|
||||||
|
|
||||||
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
|
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
|
||||||
if m: # strength
|
if m: # strength
|
||||||
strength = float(m.group(1))
|
strength = float(m.group(1))
|
||||||
@@ -2429,8 +2450,9 @@ def main(args):
|
|||||||
print("Use previous image as guide image.")
|
print("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
|
# TODO named tupleか何かにする
|
||||||
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||||
(width, height, steps, scale, strength))
|
(width, height, steps, scale, negative_scale, strength))
|
||||||
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
@@ -2527,7 +2549,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
||||||
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
||||||
parser.add_argument("--negative_scale", type=float, default=None,
|
parser.add_argument("--negative_scale", type=float, default=None,
|
||||||
help="scaling negative prompt")
|
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user