mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add skip input blocks to lora control net
This commit is contained in:
@@ -5,7 +5,8 @@ from networks.lora import LoRAModule, LoRANetwork
|
|||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
|
|
||||||
|
|
||||||
SKIP_OUTPUT_BLOCKS = True
|
SKIP_INPUT_BLOCKS = True
|
||||||
|
SKIP_OUTPUT_BLOCKS = False
|
||||||
SKIP_CONV2D = False
|
SKIP_CONV2D = False
|
||||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
|
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
|
||||||
ATTN1_ETC_ONLY = True
|
ATTN1_ETC_ONLY = True
|
||||||
@@ -123,6 +124,8 @@ class LoRAControlNet(torch.nn.Module):
|
|||||||
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
||||||
index1 = int(index1)
|
index1 = int(index1)
|
||||||
if block_name == "input_blocks":
|
if block_name == "input_blocks":
|
||||||
|
if SKIP_INPUT_BLOCKS:
|
||||||
|
continue
|
||||||
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
||||||
elif block_name == "middle_block":
|
elif block_name == "middle_block":
|
||||||
depth = 3
|
depth = 3
|
||||||
|
|||||||
Reference in New Issue
Block a user