add an option to disable controlnet in 2nd stage

This commit is contained in:
Kohya S
2023-06-15 20:56:12 +09:00
parent 9806b00f74
commit f0bb3ae825

View File

@@ -615,11 +615,15 @@ class PipelineLike:
# ControlNet
self.control_nets: List[ControlNetInfo] = []
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
def set_enable_control_net(self, en: bool):
self.control_net_enabled = en
def replace_token(self, tokens, layer=None):
new_tokens = []
for token in tokens:
@@ -1112,7 +1116,7 @@ class PipelineLike:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
if self.control_nets:
if self.control_nets and self.control_net_enabled:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
@@ -2233,7 +2237,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
replacers.append(make_replacer_single(variants, count_range, separator))
# make each prompt
if not enumerating:
if not enumerating:
# if not enumerating, repeat the prompt, replace each variant randomly
prompts = []
for _ in range(repeat_count):
@@ -2254,7 +2258,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
for replecement in replecements:
new_prompts.append(current.replace(found.group(0), replecement))
prompts = new_prompts
for found, replacer in zip(founds, replacers):
# make random selection for existing prompts
if found.group(2) is None:
@@ -2933,6 +2937,8 @@ def main(args):
ext.num_sub_prompts,
)
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
@@ -2976,6 +2982,9 @@ def main(args):
batch_2nd.append(bd_2nd)
batch = batch_2nd
if args.highres_fix_disable_control_net:
pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする
# このバッチの情報を取り出す
(
return_latents,
@@ -3574,6 +3583,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)
parser.add_argument(
"--highres_fix_disable_control_net",
action="store_true",
help="disable ControlNet for highres fix / highres fixでControlNetを使わない",
)
parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"