From f0bb3ae825efe6720f10301ee788072542b2e3ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:56:12 +0900 Subject: [PATCH] add an option to disable controlnet in 2nd stage --- gen_img_diffusers.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index acff1ea4..93a876ab 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -615,11 +615,15 @@ class PipelineLike: # ControlNet self.control_nets: List[ControlNetInfo] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + def replace_token(self, tokens, layer=None): new_tokens = [] for token in tokens: @@ -1112,7 +1116,7 @@ class PipelineLike: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - if self.control_nets: + if self.control_nets and self.control_net_enabled: if reginonal_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -2233,7 +2237,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): replacers.append(make_replacer_single(variants, count_range, separator)) # make each prompt - if not enumerating: + if not enumerating: # if not enumerating, repeat the prompt, replace each variant randomly prompts = [] for _ in range(repeat_count): @@ -2254,7 +2258,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): for replecement in replecements: new_prompts.append(current.replace(found.group(0), replecement)) prompts = new_prompts - + for found, replacer in zip(founds, replacers): # make random selection for existing prompts if found.group(2) is None: @@ -2933,6 +2937,8 @@ def main(args): ext.num_sub_prompts, ) batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する @@ -2976,6 +2982,9 @@ def main(args): batch_2nd.append(bd_2nd) batch = batch_2nd + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + # このバッチの情報を取り出す ( return_latents, @@ -3574,6 +3583,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) parser.add_argument( "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"