From af3a55b21d8efa03307153c66d07842bca211391 Mon Sep 17 00:00:00 2001 From: woctordho Date: Fri, 20 Feb 2026 12:16:43 +0800 Subject: [PATCH 1/5] Add min_orig_resolution and max_orig_resolution --- library/config_util.py | 10 +- library/train_util.py | 209 ++++++++++++++++++++++++++++++++++++----- train_network.py | 4 + 3 files changed, 198 insertions(+), 25 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 53727f25..d41e6166 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,6 +108,8 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None + min_orig_resolution: float = 0.0 + max_orig_resolution: float = float("inf") @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -118,7 +120,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -244,6 +246,8 @@ class ConfigSanitizer: "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, + "min_orig_resolution": Any(float, int), + "max_orig_resolution": Any(float, int), } # options handled by argparse but not handled by user config @@ -256,6 +260,8 @@ class ConfigSanitizer: ARGPARSE_NULLABLE_OPTNAMES = [ "face_crop_aug_range", "resolution", + "min_orig_resolution", + "max_orig_resolution", ] # prepare map because option name may differ among argparse and user config ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { @@ -528,6 +534,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} + min_orig_resolution: {dataset.min_orig_resolution} + max_orig_resolution: {dataset.max_orig_resolution} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) diff --git a/library/train_util.py b/library/train_util.py index d8577b9d..e5645cea 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,6 +687,8 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, + min_orig_resolution: float = 0.0, + max_orig_resolution: float = float("inf"), ) -> None: super().__init__() @@ -727,6 +729,12 @@ class BaseDataset(torch.utils.data.Dataset): ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation + assert ( + min_orig_resolution <= max_orig_resolution + ), f"min_orig_resolution {min_orig_resolution} cannot be larger than max_orig_resolution {max_orig_resolution}" + self.min_orig_resolution = min_orig_resolution + self.max_orig_resolution = max_orig_resolution + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -772,6 +780,50 @@ class BaseDataset(torch.utils.data.Dataset): return min_bucket_reso, max_bucket_reso + def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool: + orig_resolution = math.sqrt(image_size[0] * image_size[1]) + # min_orig_resolution is exclusive, max_orig_resolution is inclusive + return self.min_orig_resolution < orig_resolution <= self.max_orig_resolution + + def has_orig_resolution_filter(self) -> bool: + return not 
(self.min_orig_resolution == 0.0 and self.max_orig_resolution == float("inf")) + + def update_dataset_image_counts(self): + for subset in self.subsets: + subset.img_count = 0 + + num_train_images = 0 + num_reg_images = 0 + for image_key, image_info in self.image_data.items(): + subset = self.image_to_subset[image_key] + subset.img_count += 1 + + if image_info.is_reg: + num_reg_images += image_info.num_repeats + else: + num_train_images += image_info.num_repeats + + self.num_train_images = num_train_images + self.num_reg_images = num_reg_images + + def filter_registered_images_by_orig_resolution(self) -> int: + if not self.has_orig_resolution_filter(): + return 0 + + filtered_count = 0 + for image_key, image_info in list(self.image_data.items()): + if self.check_orig_resolution(image_info.image_size): + continue + + del self.image_data[image_key] + del self.image_to_subset[image_key] + filtered_count += 1 + + if filtered_count > 0: + self.update_dataset_image_counts() + + return filtered_count + def set_seed(self, seed): self.seed = seed @@ -994,6 +1046,10 @@ class BaseDataset(torch.utils.data.Dataset): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + filtered_count = self.filter_registered_images_by_orig_resolution() + if filtered_count > 0: + logger.info(f"filtered {filtered_count} images by original resolution") + # # run in parallel # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) # with ThreadPoolExecutor(max_workers) as executor: @@ -1895,6 +1951,57 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + def register_regularization_images( + self, reg_infos: Sequence[Tuple[ImageInfo, DreamBoothSubset]], num_train_images: int + ) -> None: + if len(reg_infos) == 0 or num_train_images <= 0: + return + + n = 0 + first_loop = True + while n < num_train_images: + for info, subset in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: + break + first_loop = False + + def rebalance_regularization_images(self): + if not self.is_training_dataset: + return + + reg_infos = [] + for image_key, image_info in list(self.image_data.items()): + if not image_info.is_reg: + continue + + reg_infos.append((image_info, self.image_to_subset[image_key])) + del self.image_data[image_key] + del self.image_to_subset[image_key] + + num_train_images = sum(info.num_repeats for info in self.image_data.values()) + if len(reg_infos) == 0: + return + + for info, subset in reg_infos: + info.num_repeats = subset.num_repeats + + self.register_regularization_images(reg_infos, num_train_images) + + def filter_registered_images_by_orig_resolution(self) -> int: + filtered_count = super().filter_registered_images_by_orig_resolution() + + if filtered_count > 0 and self.is_training_dataset: + self.rebalance_regularization_images() + self.update_dataset_image_counts() + + return filtered_count + # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset @@ -1915,8 +2022,17 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], + min_orig_resolution: Optional[float] = None, + max_orig_resolution: Optional[float] = None, ) -> None: - super().__init__(resolution, 
network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + min_orig_resolution, + max_orig_resolution, + ) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2059,7 +2175,7 @@ class DreamBoothDataset(BaseDataset): logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: - captions = [meta["caption"] for meta in metas.values()] + captions = [metas[img_path]["caption"] for img_path in img_paths] missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""] else: # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う @@ -2166,20 +2282,7 @@ class DreamBoothDataset(BaseDataset): if num_reg_images == 0: logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info, subset in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats - else: - info.num_repeats += 1 # rewrite registered info - n += 1 - if n >= num_train_images: - break - first_loop = False + self.register_regularization_images(reg_infos, num_train_images) self.num_reg_images = num_reg_images @@ -2200,8 +2303,17 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], + min_orig_resolution: Optional[float] = None, + max_orig_resolution: Optional[float] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + min_orig_resolution, + max_orig_resolution, + ) self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう @@ -2387,8 +2499,17 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, + min_orig_resolution: float = 0.0, + max_orig_resolution: float = float("inf"), ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + min_orig_resolution, + max_orig_resolution, + ) db_subsets = [] for subset in subsets: @@ -2440,6 +2561,8 @@ class ControlNetDataset(BaseDataset): validation_split, validation_seed, resize_interpolation, + min_orig_resolution, + max_orig_resolution, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2484,12 +2607,25 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert ( - len(missing_imgs) == 0 - ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" - assert ( - len(extra_imgs) == 0 - ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + if self.has_orig_resolution_filter(): + if len(missing_imgs) > 0: + logger.warning( + f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled" + + f" / 元画像解像度フィルタが有効なため、{len(missing_imgs)}枚の不足した制御用画像を無視します" + ) + if len(extra_imgs) > 0: + logger.warning( + f"ignore {len(extra_imgs)} extra conditioning images because 
original-resolution filtering is enabled" + + f" / 元画像解像度フィルタが有効なため、{len(extra_imgs)}枚の余分な制御用画像を無視します" + ) + # Later in `make_buckets` we assert `len(missing_imgs) == 0` but still ignore `extra_imgs` + else: + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -2498,8 +2634,19 @@ class ControlNetDataset(BaseDataset): def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() + + missing_imgs = [] + for info in self.dreambooth_dataset_delegate.image_data.values(): + if info.cond_img_path is None: + missing_imgs.append(os.path.splitext(os.path.basename(info.absolute_path))[0]) + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices + self.num_train_images = self.dreambooth_dataset_delegate.num_train_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) @@ -4601,6 +4748,20 @@ def add_dataset_arguments( help="maximum resolution for buckets, must be divisible by bucket_reso_steps " " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", ) + parser.add_argument( + "--min_orig_resolution", + type=float, + default=0.0, + help="minimum original resolution for images (exclusive), defined by sqrt(width * height) before scaling" + " / 画像の元解像度の下限(排他的)。リサイズ前のsqrt(width * height)で判定します", + ) + parser.add_argument( + "--max_orig_resolution", + type=float, + default=float("inf"), + help="maximum original resolution for images (inclusive), defined by sqrt(width * height) before scaling" + " / 画像の元解像度の上限(包含的)。リサイズ前のsqrt(width * height)で判定します", + ) parser.add_argument( "--bucket_reso_steps", type=int, diff --git a/train_network.py b/train_network.py index 2f8797d2..af48200d 100644 --- a/train_network.py +++ b/train_network.py @@ -1085,6 +1085,8 @@ class NetworkTrainer: "enable_bucket": bool(dataset.enable_bucket), "min_bucket_reso": dataset.min_bucket_reso, "max_bucket_reso": dataset.max_bucket_reso, + "min_orig_resolution": dataset.min_orig_resolution, + "max_orig_resolution": dataset.max_orig_resolution, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, "resize_interpolation": dataset.resize_interpolation, @@ -1191,6 +1193,8 @@ class NetworkTrainer: "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), "ss_min_bucket_reso": dataset.min_bucket_reso, "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_min_orig_resolution": dataset.min_orig_resolution, + "ss_max_orig_resolution": dataset.max_orig_resolution, "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), From 47afa8b2bd4e49465eb4fedcf7e8dd26a0d63401 Mon Sep 17 00:00:00 2001 From: woctordho Date: Fri, 20 Feb 2026 22:59:40 +0800 Subject: [PATCH 2/5] Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution --- library/config_util.py | 12 ++++------ library/train_util.py | 54 
+++++++++++++----------------------------- train_network.py | 6 ++--- 3 files changed, 22 insertions(+), 50 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index d41e6166..7197662e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,8 +108,7 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None - min_orig_resolution: float = 0.0 - max_orig_resolution: float = float("inf") + skip_image_resolution: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -246,8 +245,7 @@ class ConfigSanitizer: "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, - "min_orig_resolution": Any(float, int), - "max_orig_resolution": Any(float, int), + "skip_image_resolution": Any(float, int), } # options handled by argparse but not handled by user config @@ -260,8 +258,7 @@ class ConfigSanitizer: ARGPARSE_NULLABLE_OPTNAMES = [ "face_crop_aug_range", "resolution", - "min_orig_resolution", - "max_orig_resolution", + "skip_image_resolution", ] # prepare map because option name may differ among argparse and user config ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { @@ -534,8 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} - min_orig_resolution: {dataset.min_orig_resolution} - max_orig_resolution: {dataset.max_orig_resolution} + skip_image_resolution: {dataset.skip_image_resolution} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) diff --git a/library/train_util.py b/library/train_util.py index e5645cea..5db837d4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,8 +687,7 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, - min_orig_resolution: float = 0.0, - max_orig_resolution: float = float("inf"), + skip_image_resolution: float = 0.0, ) -> None: super().__init__() @@ -729,11 +728,7 @@ class BaseDataset(torch.utils.data.Dataset): ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation - assert ( - min_orig_resolution <= max_orig_resolution - ), f"min_orig_resolution {min_orig_resolution} cannot be larger than max_orig_resolution {max_orig_resolution}" - self.min_orig_resolution = min_orig_resolution - self.max_orig_resolution = max_orig_resolution + self.skip_image_resolution = skip_image_resolution self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -782,11 +777,8 @@ class BaseDataset(torch.utils.data.Dataset): def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool: orig_resolution = math.sqrt(image_size[0] * image_size[1]) - # min_orig_resolution is exclusive, max_orig_resolution is inclusive - return self.min_orig_resolution < orig_resolution <= self.max_orig_resolution - - def has_orig_resolution_filter(self) -> bool: - return not (self.min_orig_resolution == 0.0 and self.max_orig_resolution == float("inf")) + # skip_image_resolution is exclusive + return self.skip_image_resolution < orig_resolution def update_dataset_image_counts(self): for subset in self.subsets: @@ -807,7 +799,7 @@ class 
BaseDataset(torch.utils.data.Dataset): self.num_reg_images = num_reg_images def filter_registered_images_by_orig_resolution(self) -> int: - if not self.has_orig_resolution_filter(): + if self.skip_image_resolution == 0: return 0 filtered_count = 0 @@ -2022,16 +2014,14 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], - min_orig_resolution: Optional[float] = None, - max_orig_resolution: Optional[float] = None, + skip_image_resolution: Optional[float] = None, ) -> None: super().__init__( resolution, network_multiplier, debug_dataset, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2303,16 +2293,14 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], - min_orig_resolution: Optional[float] = None, - max_orig_resolution: Optional[float] = None, + skip_image_resolution: Optional[float] = None, ) -> None: super().__init__( resolution, network_multiplier, debug_dataset, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) self.batch_size = batch_size @@ -2499,16 +2487,14 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, - min_orig_resolution: float = 0.0, - max_orig_resolution: float = float("inf"), + skip_image_resolution: float = 0.0, ) -> None: super().__init__( resolution, network_multiplier, debug_dataset, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) db_subsets = [] @@ -2561,8 +2547,7 @@ class ControlNetDataset(BaseDataset): validation_split, validation_seed, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2607,7 +2592,7 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - if self.has_orig_resolution_filter(): + if self.skip_image_resolution != 0: if len(missing_imgs) > 0: logger.warning( f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled" @@ -4749,18 +4734,11 @@ def add_dataset_arguments( " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", ) parser.add_argument( - "--min_orig_resolution", + "--skip_image_resolution", type=float, default=0.0, - help="minimum original resolution for images (exclusive), defined by sqrt(width * height) before scaling" - " / 画像の元解像度の下限(排他的)。リサイズ前のsqrt(width * height)で判定します", - ) - parser.add_argument( - "--max_orig_resolution", - type=float, - default=float("inf"), - help="maximum original resolution for images (inclusive), defined by sqrt(width * height) before scaling" - " / 画像の元解像度の上限(包含的)。リサイズ前のsqrt(width * height)で判定します", + help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling" + " / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します", ) parser.add_argument( "--bucket_reso_steps", diff --git a/train_network.py b/train_network.py index af48200d..2ee671e9 100644 --- a/train_network.py +++ b/train_network.py @@ -1085,8 +1085,7 @@ class NetworkTrainer: "enable_bucket": bool(dataset.enable_bucket), 
"min_bucket_reso": dataset.min_bucket_reso, "max_bucket_reso": dataset.max_bucket_reso, - "min_orig_resolution": dataset.min_orig_resolution, - "max_orig_resolution": dataset.max_orig_resolution, + "skip_image_resolution": dataset.skip_image_resolution, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, "resize_interpolation": dataset.resize_interpolation, @@ -1193,8 +1192,7 @@ class NetworkTrainer: "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), "ss_min_bucket_reso": dataset.min_bucket_reso, "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_min_orig_resolution": dataset.min_orig_resolution, - "ss_max_orig_resolution": dataset.max_orig_resolution, + "ss_skip_image_resolution": dataset.skip_image_resolution, "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), From 3cdd62bbbf5026e942fbdc0a2f1b7357cbf381fd Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 23 Feb 2026 17:05:49 +0800 Subject: [PATCH 3/5] Change skip_image_resolution to tuple --- library/config_util.py | 4 ++-- library/train_util.py | 31 +++++++++++++++++++------------ 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 7197662e..b31f9665 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,7 +108,7 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None - skip_image_resolution: float = 0.0 + skip_image_resolution: Optional[Tuple[int, int]] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -245,7 +245,7 @@ class ConfigSanitizer: "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, - "skip_image_resolution": Any(float, int), + "skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } # options handled by argparse but not handled by user config diff --git a/library/train_util.py b/library/train_util.py index 5db837d4..5b9d0615 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,7 +687,7 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, - skip_image_resolution: float = 0.0, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() @@ -776,9 +776,8 @@ class BaseDataset(torch.utils.data.Dataset): return min_bucket_reso, max_bucket_reso def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool: - orig_resolution = math.sqrt(image_size[0] * image_size[1]) # skip_image_resolution is exclusive - return self.skip_image_resolution < orig_resolution + return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1] def update_dataset_image_counts(self): for subset in self.subsets: @@ -799,7 +798,7 @@ class BaseDataset(torch.utils.data.Dataset): self.num_reg_images = num_reg_images def filter_registered_images_by_orig_resolution(self) -> int: - if self.skip_image_resolution == 0: + if self.skip_image_resolution is None: return 0 filtered_count = 0 @@ -2014,7 +2013,7 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], - skip_image_resolution: Optional[float] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> 
None: super().__init__( resolution, @@ -2293,7 +2292,7 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], - skip_image_resolution: Optional[float] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__( resolution, @@ -2487,7 +2486,7 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, - skip_image_resolution: float = 0.0, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__( resolution, @@ -2592,7 +2591,7 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - if self.skip_image_resolution != 0: + if self.skip_image_resolution is not None: if len(missing_imgs) > 0: logger.warning( f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled" @@ -4735,10 +4734,10 @@ def add_dataset_arguments( ) parser.add_argument( "--skip_image_resolution", - type=float, - default=0.0, - help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling" - " / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します", + type=str, + default=None, + help="images not larger than this resolution will be skipped ('size' or 'width,height')" + " / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)", ) parser.add_argument( "--bucket_reso_steps", @@ -5553,6 +5552,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): len(args.resolution) == 2 ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + if args.skip_image_resolution is not None: + args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")]) + if len(args.skip_image_resolution) == 1: + args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0]) + assert ( + len(args.skip_image_resolution) == 2 + ), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}" + if args.face_crop_aug_range is not None: args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) assert ( From 5af418025de26ce028beae6e9d191ce591d77fb8 Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 23 Feb 2026 17:28:04 +0800 Subject: [PATCH 4/5] Move filtering to __init__ --- library/train_util.py | 175 ++++++++++++------------------------------ 1 file changed, 50 insertions(+), 125 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5b9d0615..c96e0b85 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -775,46 +775,6 @@ class BaseDataset(torch.utils.data.Dataset): return min_bucket_reso, max_bucket_reso - def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool: - # skip_image_resolution is exclusive - return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1] - - def update_dataset_image_counts(self): - for subset in self.subsets: - subset.img_count = 0 - - num_train_images = 0 - num_reg_images = 0 - for image_key, image_info in self.image_data.items(): - subset = self.image_to_subset[image_key] - subset.img_count += 1 - - if image_info.is_reg: - 
num_reg_images += image_info.num_repeats - else: - num_train_images += image_info.num_repeats - - self.num_train_images = num_train_images - self.num_reg_images = num_reg_images - - def filter_registered_images_by_orig_resolution(self) -> int: - if self.skip_image_resolution is None: - return 0 - - filtered_count = 0 - for image_key, image_info in list(self.image_data.items()): - if self.check_orig_resolution(image_info.image_size): - continue - - del self.image_data[image_key] - del self.image_to_subset[image_key] - filtered_count += 1 - - if filtered_count > 0: - self.update_dataset_image_counts() - - return filtered_count - def set_seed(self, seed): self.seed = seed @@ -1037,10 +997,6 @@ class BaseDataset(torch.utils.data.Dataset): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) - filtered_count = self.filter_registered_images_by_orig_resolution() - if filtered_count > 0: - logger.info(f"filtered {filtered_count} images by original resolution") - # # run in parallel # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) # with ThreadPoolExecutor(max_workers) as executor: @@ -1942,57 +1898,6 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" - def register_regularization_images( - self, reg_infos: Sequence[Tuple[ImageInfo, DreamBoothSubset]], num_train_images: int - ) -> None: - if len(reg_infos) == 0 or num_train_images <= 0: - return - - n = 0 - first_loop = True - while n < num_train_images: - for info, subset in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats - else: - info.num_repeats += 1 - n += 1 - if n >= num_train_images: - break - first_loop = False - - def rebalance_regularization_images(self): - if not self.is_training_dataset: - return - - reg_infos = [] - for image_key, image_info in list(self.image_data.items()): - if not image_info.is_reg: - continue - - reg_infos.append((image_info, self.image_to_subset[image_key])) - del self.image_data[image_key] - del self.image_to_subset[image_key] - - num_train_images = sum(info.num_repeats for info in self.image_data.values()) - if len(reg_infos) == 0: - return - - for info, subset in reg_infos: - info.num_repeats = subset.num_repeats - - self.register_regularization_images(reg_infos, num_train_images) - - def filter_registered_images_by_orig_resolution(self) -> int: - filtered_count = super().filter_registered_images_by_orig_resolution() - - if filtered_count > 0 and self.is_training_dataset: - self.rebalance_regularization_images() - self.update_dataset_image_counts() - - return filtered_count - # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset @@ -2139,6 +2044,22 @@ class DreamBoothDataset(BaseDataset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.skip_image_resolution is not None: + filtered_img_paths = [] + filtered_sizes = [] + skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1] + for img_path, size in zip(img_paths, sizes): + if size is None: # no latents cache file, get image size by reading image file (slow) + size = self.get_image_size(img_path) + if size[0] * size[1] <= skip_image_area: + continue + filtered_img_paths.append(img_path) + filtered_sizes.append(size) + if len(filtered_img_paths) < 
len(img_paths): + logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}") + img_paths = filtered_img_paths + sizes = filtered_sizes + # We want to create a training and validation split. This should be improved in the future # to allow a clearer distinction between training and validation. This can be seen as a # short-term solution to limit what is necessary to implement validation datasets @@ -2271,7 +2192,20 @@ class DreamBoothDataset(BaseDataset): if num_reg_images == 0: logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: - self.register_regularization_images(reg_infos, num_train_images) + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info, subset in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 # rewrite registered info + n += 1 + if n >= num_train_images: + break + first_loop = False self.num_reg_images = num_reg_images @@ -2396,6 +2330,7 @@ class FineTuningDataset(BaseDataset): tags_list = [] size_set_from_metadata = 0 size_set_from_cache_filename = 0 + num_filtered = 0 for image_key in image_keys_sorted_by_length_desc: img_md = metadata[image_key] caption = img_md.get("caption") @@ -2454,6 +2389,16 @@ class FineTuningDataset(BaseDataset): image_info.image_size = (w, h) size_set_from_cache_filename += 1 + if self.skip_image_resolution is not None: + size = image_info.image_size + if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow) + size = self.get_image_size(abs_path) + image_info.image_size = size + skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1] + if size[0] * size[1] <= skip_image_area: + num_filtered += 1 + continue + self.register_image(image_info, subset) if size_set_from_cache_filename > 0: @@ -2462,6 +2407,8 @@ class FineTuningDataset(BaseDataset): ) if size_set_from_metadata > 0: logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}") + if num_filtered > 0: + logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}") self.num_train_images += len(metadata) * subset.num_repeats # TODO do not record tag freq when no tag @@ -2591,25 +2538,13 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - if self.skip_image_resolution is not None: - if len(missing_imgs) > 0: - logger.warning( - f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled" - + f" / 元画像解像度フィルタが有効なため、{len(missing_imgs)}枚の不足した制御用画像を無視します" - ) - if len(extra_imgs) > 0: - logger.warning( - f"ignore {len(extra_imgs)} extra conditioning images because original-resolution filtering is enabled" - + f" / 元画像解像度フィルタが有効なため、{len(extra_imgs)}枚の余分な制御用画像を無視します" - ) - # Later in `make_buckets` we assert `len(missing_imgs) == 0` but still ignore `extra_imgs` - else: - assert ( - len(missing_imgs) == 0 - ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" - assert ( - len(extra_imgs) == 0 - ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for 
{len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + if len(extra_imgs) > 0: + logger.warning( + f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + ) self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -2619,18 +2554,8 @@ class ControlNetDataset(BaseDataset): def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() - missing_imgs = [] - for info in self.dreambooth_dataset_delegate.image_data.values(): - if info.cond_img_path is None: - missing_imgs.append(os.path.splitext(os.path.basename(info.absolute_path))[0]) - assert ( - len(missing_imgs) == 0 - ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" - self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) From 1e4f55cc53b551fc1d2946a60401fcac378cbeb1 Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 23 Feb 2026 17:36:46 +0800 Subject: [PATCH 5/5] Minor fix --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index c96e0b85..b65f06b9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2553,7 +2553,6 @@ class ControlNetDataset(BaseDataset): def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() - self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
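
Taken together, the series leaves this behavior: `--skip_image_resolution` accepts 'size' or 'width,height', and an image is skipped when its pixel area is not larger than the product of the two values, measured on the original size before any resizing or bucketing. The sketch below is an illustration of that logic, not part of the diff; the helper names parse_skip_image_resolution and should_skip are invented for the example.

from typing import Optional, Tuple

def parse_skip_image_resolution(arg: Optional[str]) -> Optional[Tuple[int, int]]:
    # Mirrors the parsing added to prepare_dataset_args in PATCH 3/5:
    # accept "size" or "width,height"; a single value means a square.
    if arg is None:
        return None
    parts = tuple(int(r) for r in arg.split(","))
    if len(parts) == 1:
        parts = (parts[0], parts[0])
    assert len(parts) == 2, "skip_image_resolution must be 'size' or 'width,height'"
    return parts

def should_skip(image_size: Tuple[int, int], skip: Optional[Tuple[int, int]]) -> bool:
    # The threshold itself is skipped (inclusive on the skip side): an image
    # survives only if its area strictly exceeds skip[0] * skip[1], matching
    # the `size[0] * size[1] <= skip_image_area` checks in PATCH 4/5.
    if skip is None:
        return False
    return image_size[0] * image_size[1] <= skip[0] * skip[1]

# With --skip_image_resolution 512, a 512x512 image (262144 px) is skipped,
# while a 640x480 image (307200 px) is kept.
skip = parse_skip_image_resolution("512")
assert should_skip((512, 512), skip)
assert not should_skip((640, 480), skip)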
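
A note on the design choice in PATCH 3/5: switching the option from a float to a (width, height) pair does not change the metric from the earlier patches, because comparing pixel areas is equivalent to comparing the sqrt(width * height) values used in PATCH 1-2 (sqrt is monotonic). A quick illustrative check with arbitrary sizes:

import math

w, h = 600, 400    # original image size
sw, sh = 512, 512  # skip threshold as a (width, height) pair
# Area comparison (PATCH 3/5) agrees with the geometric-mean comparison
# described in the PATCH 1-2 help texts.
assert (w * h <= sw * sh) == (math.sqrt(w * h) <= math.sqrt(sw * sh))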
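
One behavioral consequence worth noting: after PATCH 4/5, images dropped by the filter leave their paired conditioning images behind, so ControlNetDataset warns about extra conditioning images instead of asserting, while missing conditioning images remain a hard error.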