diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index a65cc1fd..7a026eba 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,7 +5,8 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet -SKIP_OUTPUT_BLOCKS = True +SKIP_INPUT_BLOCKS = True +SKIP_OUTPUT_BLOCKS = False SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored ATTN1_ETC_ONLY = True @@ -123,6 +124,8 @@ class LoRAControlNet(torch.nn.Module): block_name, index1, index2 = (name + "." + child_name).split(".")[:3] index1 = int(index1) if block_name == "input_blocks": + if SKIP_INPUT_BLOCKS: + continue depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) elif block_name == "middle_block": depth = 3