add lora controlnet train/gen temporarily

This commit is contained in:
Kohya S
2023-08-17 10:08:02 +09:00
parent 983698dd1b
commit 3f7235c36f
6 changed files with 3582 additions and 83 deletions

View File

@@ -1743,6 +1743,9 @@ class ControlNetDataset(BaseDataset):
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def __len__(self):
return self.dreambooth_dataset_delegate.__len__()
@@ -1775,9 +1778,14 @@ class ControlNetDataset(BaseDataset):
h, w = target_size_hw
cond_img = cond_img[ct : ct + h, cl : cl + w]
else:
assert (
cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# assert (
# cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = cv2.resize(
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride