fix error in generation

This commit is contained in:
ykume
2023-08-17 18:31:29 +09:00
parent 5fa473d5f3
commit 809fca0be9

View File

@@ -9,8 +9,8 @@ SKIP_INPUT_BLOCKS = False
SKIP_OUTPUT_BLOCKS = True SKIP_OUTPUT_BLOCKS = True
SKIP_CONV2D = False SKIP_CONV2D = False
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
ATTN1_ETC_ONLY = True ATTN1_ETC_ONLY = False # True
TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks
class LoRAModuleControlNet(LoRAModule): class LoRAModuleControlNet(LoRAModule):
@@ -77,11 +77,11 @@ class LoRAModuleControlNet(LoRAModule):
# conditioning image # conditioning image
cx = self.cond_emb cx = self.cond_emb
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only
if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond: if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero cx[0::2] = 0.0 # uncond is zero
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
cx = self.conditioning2(cx) cx = self.conditioning2(cx)
@@ -303,7 +303,7 @@ if __name__ == "__main__":
unet.to("cuda").to(torch.float16) unet.to("cuda").to(torch.float16)
print("create LoRA controlnet") print("create LoRA controlnet")
control_net = LoRAControlNet(unet, 256, 64, 1) control_net = LoRAControlNet(unet, 128, 64, 1)
control_net.apply_to() control_net.apply_to()
control_net.to("cuda") control_net.to("cuda")