mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix bucketing doesn't work in controlnet training
This commit is contained in:
@@ -1770,12 +1770,16 @@ class ControlNetDataset(BaseDataset):
|
||||
cond_img = load_image(image_info.cond_img_path)
|
||||
|
||||
if self.dreambooth_dataset_delegate.enable_bucket:
|
||||
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
assert (
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
ct, cl = crop_top_left
|
||||
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
|
||||
# TODO support random crop
|
||||
# 現在サポートしているcropはrandomではなく中央のみ
|
||||
h, w = target_size_hw
|
||||
ct = (cond_img.shape[0] - h) // 2
|
||||
cl = (cond_img.shape[1] - w) // 2
|
||||
cond_img = cond_img[ct : ct + h, cl : cl + w]
|
||||
else:
|
||||
# assert (
|
||||
|
||||
@@ -120,6 +120,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||
cx = self.conditioning1(cond_image)
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
@@ -146,7 +147,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||
|
||||
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||
|
||||
@@ -113,6 +113,8 @@ def train(args):
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
else:
|
||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user