diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 7a026eba..9259389f 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,11 +5,12 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet -SKIP_INPUT_BLOCKS = True -SKIP_OUTPUT_BLOCKS = False +SKIP_INPUT_BLOCKS = False +SKIP_OUTPUT_BLOCKS = True SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored ATTN1_ETC_ONLY = True +TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks class LoRAModuleControlNet(LoRAModule): @@ -48,15 +49,17 @@ class LoRAModuleControlNet(LoRAModule): self.depth = depth self.cond_emb = None - self.batch_cond_uncond_enabled = False + self.batch_cond_only = False + self.use_zeros_for_batch_uncond = False def set_cond_embs(self, cond_embs_4d, cond_embs_3d): cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d cond_emb = cond_embs[self.depth - 1] self.cond_emb = self.conditioning1(cond_emb) - def set_batch_cond_uncond_enabled(self, enabled): - self.batch_cond_uncond_enabled = enabled + def set_batch_cond_only(self, cond_only, zeros): + self.batch_cond_only = cond_only + self.use_zeros_for_batch_uncond = zeros def forward(self, x): if self.cond_emb is None: @@ -64,7 +67,7 @@ class LoRAModuleControlNet(LoRAModule): # LoRA lx = x - if self.batch_cond_uncond_enabled: + if self.batch_cond_only: lx = lx[1::2] # cond only lx = self.lora_down(lx) @@ -75,6 +78,10 @@ class LoRAModuleControlNet(LoRAModule): # conditioning image cx = self.cond_emb # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only + cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) + if self.use_zeros_for_batch_uncond: + cx[0::2] = 0.0 # uncond is zero cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = self.conditioning2(cx) @@ -84,7 +91,7 @@ class LoRAModuleControlNet(LoRAModule): x = self.org_forward(x) - if self.batch_cond_uncond_enabled: + if self.batch_cond_only: x[1::2] += lx * self.multiplier * self.scale else: x += lx * self.multiplier * self.scale @@ -141,6 +148,13 @@ class LoRAControlNet(torch.nn.Module): lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") + if TRANSFORMER_MAX_BLOCK_INDEX is not None: + p = lora_name.find("transformer_blocks") + if p >= 0: + tf_index = int(lora_name[p:].split("_")[2]) + if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: + continue + # skip time emb or clip emb if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue @@ -215,9 +229,9 @@ class LoRAControlNet(torch.nn.Module): for lora in self.unet_loras: lora.set_cond_embs(cond_embs_4d, cond_embs_3d) - def set_batch_cond_uncond_enabled(self, enabled): + def set_batch_cond_only(self, cond_only, zeros): for lora in self.unet_loras: - lora.set_batch_cond_uncond_enabled(enabled) + lora.set_batch_cond_only(cond_only, zeros) def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": @@ -294,11 +308,12 @@ if __name__ == "__main__": control_net.to("cuda") print(control_net) - input() # print number of parameters print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + input() + unet.set_use_memory_efficient_attention(True, False) unet.set_gradient_checkpointing(True) unet.train() # for gradient checkpointing diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py index e8b22ee1..b1cab576 100644 --- a/sdxl_gen_img_lora_ctrl_test.py +++ b/sdxl_gen_img_lora_ctrl_test.py @@ -1571,7 +1571,7 @@ def main(args): control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(dtype).to(device) - control_net.set_batch_cond_uncond_enabled(True) + control_net.set_batch_cond_only(False, False) control_nets.append(control_net) if args.opt_channels_last: