diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 9259389f..11b4db90 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -9,8 +9,8 @@ SKIP_INPUT_BLOCKS = False SKIP_OUTPUT_BLOCKS = True SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored -ATTN1_ETC_ONLY = True -TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks +ATTN1_ETC_ONLY = False # True +TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks class LoRAModuleControlNet(LoRAModule): @@ -77,11 +77,11 @@ class LoRAModuleControlNet(LoRAModule): # conditioning image cx = self.cond_emb - # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") - if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only + if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero + # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = self.conditioning2(cx) @@ -303,7 +303,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 256, 64, 1) + control_net = LoRAControlNet(unet, 128, 64, 1) control_net.apply_to() control_net.to("cuda")