diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 0189c632..a65cc1fd 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,7 +5,7 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet -SKIP_OUTPUT_BLOCKS = False +SKIP_OUTPUT_BLOCKS = True SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored ATTN1_ETC_ONLY = True @@ -286,7 +286,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 128, 32, 1) + control_net = LoRAControlNet(unet, 256, 64, 1) control_net.apply_to() control_net.to("cuda") diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py index 489c7936..92469add 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_lora_control_net.py @@ -813,7 +813,7 @@ def setup_parser() -> argparse.ArgumentParser: if __name__ == "__main__": - # sdxl_original_unet.USE_REENTRANT = False + sdxl_original_unet.USE_REENTRANT = False parser = setup_parser()