diff --git a/library/train_util.py b/library/train_util.py
index 631f1cb7..ff1b4a33 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1770,12 +1770,16 @@ class ControlNetDataset(BaseDataset):
             cond_img = load_image(image_info.cond_img_path)
 
             if self.dreambooth_dataset_delegate.enable_bucket:
-                cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
                 assert (
                     cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
                 ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
-                ct, cl = crop_top_left
+                cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
+
+                # TODO support random crop
+                # 現在サポートしているcropはrandomではなく中央のみ
                 h, w = target_size_hw
+                ct = (cond_img.shape[0] - h) // 2
+                cl = (cond_img.shape[1] - w) // 2
                 cond_img = cond_img[ct : ct + h, cl : cl + w]
             else:
                 # assert (
diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py
index 36e36071..3140919c 100644
--- a/networks/control_net_lllite.py
+++ b/networks/control_net_lllite.py
@@ -120,6 +120,7 @@ class LLLiteModule(torch.nn.Module):
         / call the model inside, so if necessary, surround it with torch.no_grad()
         """
         # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
+        # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
         cx = self.conditioning1(cond_image)
         if not self.is_conv2d:
             # reshape / b,c,h,w -> b,h*w,c
@@ -146,7 +147,7 @@ class LLLiteModule(torch.nn.Module):
             cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
             if self.use_zeros_for_batch_uncond:
                 cx[0::2] = 0.0  # uncond is zero
-        # print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
+        # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
 
         # downで入力の次元数を削減し、conditioning image embeddingと結合する
         # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 6e5c4232..09cf1643 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -113,6 +113,8 @@ def train(args):
         assert (
             train_dataset_group.is_latent_cacheable()
         ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+    else:
+        print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
 
     if args.cache_text_encoder_outputs:
         assert (