Change skip_image_resolution to tuple

This commit is contained in:
woctordho
2026-02-23 17:05:49 +08:00
parent 47afa8b2bd
commit 3cdd62bbbf
2 changed files with 21 additions and 14 deletions

View File

@@ -108,7 +108,7 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
skip_image_resolution: float = 0.0
skip_image_resolution: Optional[Tuple[int, int]] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -245,7 +245,7 @@ class ConfigSanitizer:
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
"skip_image_resolution": Any(float, int),
"skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
}
# options handled by argparse but not handled by user config

View File

@@ -687,7 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None,
skip_image_resolution: float = 0.0,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
@@ -776,9 +776,8 @@ class BaseDataset(torch.utils.data.Dataset):
return min_bucket_reso, max_bucket_reso
def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
orig_resolution = math.sqrt(image_size[0] * image_size[1])
# skip_image_resolution is exclusive
return self.skip_image_resolution < orig_resolution
return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1]
def update_dataset_image_counts(self):
for subset in self.subsets:
@@ -799,7 +798,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.num_reg_images = num_reg_images
def filter_registered_images_by_orig_resolution(self) -> int:
if self.skip_image_resolution == 0:
if self.skip_image_resolution is None:
return 0
filtered_count = 0
@@ -2014,7 +2013,7 @@ class DreamBoothDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
skip_image_resolution: Optional[float] = None,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(
resolution,
@@ -2293,7 +2292,7 @@ class FineTuningDataset(BaseDataset):
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
skip_image_resolution: Optional[float] = None,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(
resolution,
@@ -2487,7 +2486,7 @@ class ControlNetDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
skip_image_resolution: float = 0.0,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(
resolution,
@@ -2592,7 +2591,7 @@ class ControlNetDataset(BaseDataset):
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
if self.skip_image_resolution != 0:
if self.skip_image_resolution is not None:
if len(missing_imgs) > 0:
logger.warning(
f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
@@ -4735,10 +4734,10 @@ def add_dataset_arguments(
)
parser.add_argument(
"--skip_image_resolution",
type=float,
default=0.0,
help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling"
" / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します",
type=str,
default=None,
help="images not larger than this resolution will be skipped ('size' or 'width,height')"
" / この解像度以下の画像はスキップされます'サイズ'指定、または'幅,高さ'指定)",
)
parser.add_argument(
"--bucket_reso_steps",
@@ -5553,6 +5552,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
len(args.resolution) == 2
), f"resolution must be 'size' or 'width,height' / resolution解像度'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.skip_image_resolution is not None:
args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
if len(args.skip_image_resolution) == 1:
args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
assert (
len(args.skip_image_resolution) == 2
), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'','高さ'で指定してください: {args.skip_image_resolution}"
if args.face_crop_aug_range is not None:
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
assert (