mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix error in generation
This commit is contained in:
@@ -9,8 +9,8 @@ SKIP_INPUT_BLOCKS = False
|
||||
SKIP_OUTPUT_BLOCKS = True
|
||||
SKIP_CONV2D = False
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored
|
||||
ATTN1_ETC_ONLY = True
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks
|
||||
ATTN1_ETC_ONLY = False # True
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks
|
||||
|
||||
|
||||
class LoRAModuleControlNet(LoRAModule):
|
||||
@@ -77,11 +77,11 @@ class LoRAModuleControlNet(LoRAModule):
|
||||
|
||||
# conditioning image
|
||||
cx = self.cond_emb
|
||||
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only
|
||||
if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
|
||||
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.conditioning2(cx)
|
||||
@@ -303,7 +303,7 @@ if __name__ == "__main__":
|
||||
unet.to("cuda").to(torch.float16)
|
||||
|
||||
print("create LoRA controlnet")
|
||||
control_net = LoRAControlNet(unet, 256, 64, 1)
|
||||
control_net = LoRAControlNet(unet, 128, 64, 1)
|
||||
control_net.apply_to()
|
||||
control_net.to("cuda")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user