From 1cd95b2d8b3b994214681ab30cbdc74f9abc44ef Mon Sep 17 00:00:00 2001 From: woctordho Date: Thu, 19 Mar 2026 07:43:39 +0800 Subject: [PATCH] Add `skip_image_resolution` to deduplicate multi-resolution dataset (#2273) * Add min_orig_resolution and max_orig_resolution * Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution * Change skip_image_resolution to tuple * Move filtering to __init__ * Minor fix --- library/config_util.py | 6 ++- library/train_util.py | 84 ++++++++++++++++++++++++++++++++++++++---- train_network.py | 2 + 3 files changed, 84 insertions(+), 8 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 53727f25..b31f9665 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,6 +108,7 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None + skip_image_resolution: Optional[Tuple[int, int]] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -118,7 +119,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -244,6 +245,7 @@ class ConfigSanitizer: "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, + "skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } # options handled by argparse but not handled by user config @@ -256,6 +258,7 @@ class ConfigSanitizer: ARGPARSE_NULLABLE_OPTNAMES = [ "face_crop_aug_range", "resolution", + "skip_image_resolution", ] # prepare map because option name may differ among argparse and user config ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { @@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} + skip_image_resolution: {dataset.skip_image_resolution} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) diff --git a/library/train_util.py b/library/train_util.py index d8577b9d..b65f06b9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() @@ -727,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset): ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation + self.skip_image_resolution = skip_image_resolution + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -1915,8 +1918,15 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + skip_image_resolution, + ) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2034,6 +2044,22 @@ class DreamBoothDataset(BaseDataset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.skip_image_resolution is not None: + filtered_img_paths = [] + filtered_sizes = [] + skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1] + for img_path, size in zip(img_paths, sizes): + if size is None: # no latents cache file, get image size by reading image file (slow) + size = self.get_image_size(img_path) + if size[0] * size[1] <= skip_image_area: + continue + filtered_img_paths.append(img_path) + filtered_sizes.append(size) + if len(filtered_img_paths) < len(img_paths): + logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}") + img_paths = filtered_img_paths + sizes = filtered_sizes + # We want to create a training and validation split. This should be improved in the future # to allow a clearer distinction between training and validation. This can be seen as a # short-term solution to limit what is necessary to implement validation datasets @@ -2059,7 +2085,7 @@ class DreamBoothDataset(BaseDataset): logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: - captions = [meta["caption"] for meta in metas.values()] + captions = [metas[img_path]["caption"] for img_path in img_paths] missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""] else: # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う @@ -2200,8 +2226,15 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + skip_image_resolution, + ) self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう @@ -2297,6 +2330,7 @@ class FineTuningDataset(BaseDataset): tags_list = [] size_set_from_metadata = 0 size_set_from_cache_filename = 0 + num_filtered = 0 for image_key in image_keys_sorted_by_length_desc: img_md = metadata[image_key] caption = img_md.get("caption") @@ -2355,6 +2389,16 @@ class FineTuningDataset(BaseDataset): image_info.image_size = (w, h) size_set_from_cache_filename += 1 + if self.skip_image_resolution is not None: + size = image_info.image_size + if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow) + size = self.get_image_size(abs_path) + image_info.image_size = size + skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1] + if size[0] * size[1] <= skip_image_area: + num_filtered += 1 + continue + self.register_image(image_info, subset) if size_set_from_cache_filename > 0: @@ -2363,6 +2407,8 @@ class FineTuningDataset(BaseDataset): ) if size_set_from_metadata > 0: logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}") + if num_filtered > 0: + logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}") self.num_train_images += len(metadata) * subset.num_repeats # TODO do not record tag freq when no tag @@ -2387,8 +2433,15 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + skip_image_resolution, + ) db_subsets = [] for subset in subsets: @@ -2440,6 +2493,7 @@ class ControlNetDataset(BaseDataset): validation_split, validation_seed, resize_interpolation, + skip_image_resolution, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2487,9 +2541,10 @@ class ControlNetDataset(BaseDataset): assert ( len(missing_imgs) == 0 ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" - assert ( - len(extra_imgs) == 0 - ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + if len(extra_imgs) > 0: + logger.warning( + f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + ) self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -4601,6 +4656,13 @@ def add_dataset_arguments( help="maximum resolution for buckets, must be divisible by bucket_reso_steps " " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", ) + parser.add_argument( + "--skip_image_resolution", + type=str, + default=None, + help="images not larger than this resolution will be skipped ('size' or 'width,height')" + " / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)", + ) parser.add_argument( "--bucket_reso_steps", type=int, @@ -5414,6 +5476,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): len(args.resolution) == 2 ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + if args.skip_image_resolution is not None: + args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")]) + if len(args.skip_image_resolution) == 1: + args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0]) + assert ( + len(args.skip_image_resolution) == 2 + ), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}" + if args.face_crop_aug_range is not None: args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) assert ( diff --git a/train_network.py b/train_network.py index 2f8797d2..2ee671e9 100644 --- a/train_network.py +++ b/train_network.py @@ -1085,6 +1085,7 @@ class NetworkTrainer: "enable_bucket": bool(dataset.enable_bucket), "min_bucket_reso": dataset.min_bucket_reso, "max_bucket_reso": dataset.max_bucket_reso, + "skip_image_resolution": dataset.skip_image_resolution, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, "resize_interpolation": dataset.resize_interpolation, @@ -1191,6 +1192,7 @@ class NetworkTrainer: "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), "ss_min_bucket_reso": dataset.min_bucket_reso, "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_skip_image_resolution": dataset.skip_image_resolution, "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),