From 7081a0cf0f1ca1a543edf7cab10c4c7d497348ca Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 17 Mar 2024 18:09:15 +0900 Subject: [PATCH] extension of src image could be different than target image --- library/train_util.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 7fe5bc56..0f8cf9ee 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1863,7 +1863,7 @@ class ControlNetDataset(BaseDataset): # assert all conditioning data exists missing_imgs = [] - cond_imgs_with_img = set() + cond_imgs_with_pair = set() for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] subset = None @@ -1877,23 +1877,29 @@ class ControlNetDataset(BaseDataset): logger.warning(f"not directory: {subset.conditioning_data_dir}") continue - img_basename = os.path.basename(info.absolute_path) - ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) - if not os.path.exists(ctrl_img_path): + img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0] + ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename) + if len(ctrl_img_path) < 1: missing_imgs.append(img_basename) + continue + ctrl_img_path = ctrl_img_path[0] + ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path info.cond_img_path = ctrl_img_path - cond_imgs_with_img.add(ctrl_img_path) + cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive extra_imgs = [] for subset in subsets: conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") - extra_imgs.extend( - [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] - ) + conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path + extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS