From 3cdd62bbbf5026e942fbdc0a2f1b7357cbf381fd Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 23 Feb 2026 17:05:49 +0800 Subject: [PATCH] Change skip_image_resolution to tuple --- library/config_util.py | 4 ++-- library/train_util.py | 31 +++++++++++++++++++------------ 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 7197662e..b31f9665 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,7 +108,7 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None - skip_image_resolution: float = 0.0 + skip_image_resolution: Optional[Tuple[int, int]] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -245,7 +245,7 @@ class ConfigSanitizer: "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, - "skip_image_resolution": Any(float, int), + "skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } # options handled by argparse but not handled by user config diff --git a/library/train_util.py b/library/train_util.py index 5db837d4..5b9d0615 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,7 +687,7 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, - skip_image_resolution: float = 0.0, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() @@ -776,9 +776,8 @@ class BaseDataset(torch.utils.data.Dataset): return min_bucket_reso, max_bucket_reso def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool: - orig_resolution = math.sqrt(image_size[0] * image_size[1]) # skip_image_resolution is exclusive - return self.skip_image_resolution < orig_resolution + return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1] def update_dataset_image_counts(self): for subset in self.subsets: @@ -799,7 +798,7 @@ class BaseDataset(torch.utils.data.Dataset): self.num_reg_images = num_reg_images def filter_registered_images_by_orig_resolution(self) -> int: - if self.skip_image_resolution == 0: + if self.skip_image_resolution is None: return 0 filtered_count = 0 @@ -2014,7 +2013,7 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], - skip_image_resolution: Optional[float] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__( resolution, @@ -2293,7 +2292,7 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], - skip_image_resolution: Optional[float] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__( resolution, @@ -2487,7 +2486,7 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, - skip_image_resolution: float = 0.0, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__( resolution, @@ -2592,7 +2591,7 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - if self.skip_image_resolution != 0: + if self.skip_image_resolution is not None: if len(missing_imgs) > 0: logger.warning( f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled" @@ -4735,10 +4734,10 @@ def add_dataset_arguments( ) parser.add_argument( "--skip_image_resolution", - type=float, - default=0.0, - help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling" - " / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します", + type=str, + default=None, + help="images not larger than this resolution will be skipped ('size' or 'width,height')" + " / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)", ) parser.add_argument( "--bucket_reso_steps", @@ -5553,6 +5552,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): len(args.resolution) == 2 ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + if args.skip_image_resolution is not None: + args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")]) + if len(args.skip_image_resolution) == 1: + args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0]) + assert ( + len(args.skip_image_resolution) == 2 + ), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}" + if args.face_crop_aug_range is not None: args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) assert (