diff --git a/library/config_util.py b/library/config_util.py index d41e6166..7197662e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,8 +108,7 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None - min_orig_resolution: float = 0.0 - max_orig_resolution: float = float("inf") + skip_image_resolution: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -246,8 +245,7 @@ class ConfigSanitizer: "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, - "min_orig_resolution": Any(float, int), - "max_orig_resolution": Any(float, int), + "skip_image_resolution": Any(float, int), } # options handled by argparse but not handled by user config @@ -260,8 +258,7 @@ class ConfigSanitizer: ARGPARSE_NULLABLE_OPTNAMES = [ "face_crop_aug_range", "resolution", - "min_orig_resolution", - "max_orig_resolution", + "skip_image_resolution", ] # prepare map because option name may differ among argparse and user config ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { @@ -534,8 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} - min_orig_resolution: {dataset.min_orig_resolution} - max_orig_resolution: {dataset.max_orig_resolution} + skip_image_resolution: {dataset.skip_image_resolution} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) diff --git a/library/train_util.py b/library/train_util.py index e5645cea..5db837d4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,8 +687,7 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, - min_orig_resolution: float = 0.0, - max_orig_resolution: float 
= float("inf"), + skip_image_resolution: Optional[float] = 0.0, ) -> None: super().__init__() @@ -729,11 +728,7 @@ class BaseDataset(torch.utils.data.Dataset): ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation - assert ( - min_orig_resolution <= max_orig_resolution - ), f"min_orig_resolution {min_orig_resolution} cannot be larger than max_orig_resolution {max_orig_resolution}" - self.min_orig_resolution = min_orig_resolution - self.max_orig_resolution = max_orig_resolution + self.skip_image_resolution = 0.0 if skip_image_resolution is None else skip_image_resolution self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -782,11 +777,8 @@ class BaseDataset(torch.utils.data.Dataset): def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool: orig_resolution = math.sqrt(image_size[0] * image_size[1]) - # min_orig_resolution is exclusive, max_orig_resolution is inclusive - return self.min_orig_resolution < orig_resolution <= self.max_orig_resolution - - def has_orig_resolution_filter(self) -> bool: - return not (self.min_orig_resolution == 0.0 and self.max_orig_resolution == float("inf")) + # skip_image_resolution is exclusive + return self.skip_image_resolution < orig_resolution def update_dataset_image_counts(self): for subset in self.subsets: @@ -807,7 +799,7 @@ class BaseDataset(torch.utils.data.Dataset): self.num_reg_images = num_reg_images def filter_registered_images_by_orig_resolution(self) -> int: - if not self.has_orig_resolution_filter(): + if self.skip_image_resolution == 0: return 0 filtered_count = 0 @@ -2022,16 +2014,14 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], - min_orig_resolution: Optional[float] = None, - max_orig_resolution: Optional[float] = None, + skip_image_resolution: Optional[float] = None, ) -> None: super().__init__( resolution, 
network_multiplier, debug_dataset, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2303,16 +2293,14 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], - min_orig_resolution: Optional[float] = None, - max_orig_resolution: Optional[float] = None, + skip_image_resolution: Optional[float] = None, ) -> None: super().__init__( resolution, network_multiplier, debug_dataset, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) self.batch_size = batch_size @@ -2499,16 +2487,14 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, - min_orig_resolution: float = 0.0, - max_orig_resolution: float = float("inf"), + skip_image_resolution: float = 0.0, ) -> None: super().__init__( resolution, network_multiplier, debug_dataset, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) db_subsets = [] @@ -2561,8 +2547,7 @@ class ControlNetDataset(BaseDataset): validation_split, validation_seed, resize_interpolation, - min_orig_resolution, - max_orig_resolution, + skip_image_resolution, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2607,7 +2592,7 @@ class ControlNetDataset(BaseDataset): conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - if self.has_orig_resolution_filter(): + if self.skip_image_resolution != 0: if len(missing_imgs) > 0: logger.warning( f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled" @@ -4749,18 +4734,11 @@ def add_dataset_arguments( " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", ) 
parser.add_argument( - "--min_orig_resolution", + "--skip_image_resolution", type=float, default=0.0, - help="minimum original resolution for images (exclusive), defined by sqrt(width * height) before scaling" - " / 画像の元解像度の下限(排他的)。リサイズ前のsqrt(width * height)で判定します", - ) - parser.add_argument( - "--max_orig_resolution", - type=float, - default=float("inf"), - help="maximum original resolution for images (inclusive), defined by sqrt(width * height) before scaling" - " / 画像の元解像度の上限(包含的)。リサイズ前のsqrt(width * height)で判定します", + help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling" + " / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します", ) parser.add_argument( "--bucket_reso_steps", diff --git a/train_network.py b/train_network.py index af48200d..2ee671e9 100644 --- a/train_network.py +++ b/train_network.py @@ -1085,8 +1085,7 @@ class NetworkTrainer: "enable_bucket": bool(dataset.enable_bucket), "min_bucket_reso": dataset.min_bucket_reso, "max_bucket_reso": dataset.max_bucket_reso, - "min_orig_resolution": dataset.min_orig_resolution, - "max_orig_resolution": dataset.max_orig_resolution, + "skip_image_resolution": dataset.skip_image_resolution, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, "resize_interpolation": dataset.resize_interpolation, @@ -1193,8 +1192,7 @@ class NetworkTrainer: "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), "ss_min_bucket_reso": dataset.min_bucket_reso, "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_min_orig_resolution": dataset.min_orig_resolution, - "ss_max_orig_resolution": dataset.max_orig_resolution, + "ss_skip_image_resolution": dataset.skip_image_resolution, "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),