mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix bucketing doesn't work in controlnet training
This commit is contained in:
@@ -120,6 +120,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||
cx = self.conditioning1(cond_image)
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
@@ -146,7 +147,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||
|
||||
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||
|
||||
Reference in New Issue
Block a user