Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution

This commit is contained in:
woctordho
2026-02-20 22:59:40 +08:00
parent af3a55b21d
commit 47afa8b2bd
3 changed files with 22 additions and 50 deletions

View File

@@ -108,8 +108,7 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
min_orig_resolution: float = 0.0
max_orig_resolution: float = float("inf")
skip_image_resolution: float = 0.0
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -246,8 +245,7 @@ class ConfigSanitizer:
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
"min_orig_resolution": Any(float, int),
"max_orig_resolution": Any(float, int),
"skip_image_resolution": Any(float, int),
}
# options handled by argparse but not handled by user config
@@ -260,8 +258,7 @@ class ConfigSanitizer:
ARGPARSE_NULLABLE_OPTNAMES = [
"face_crop_aug_range",
"resolution",
"min_orig_resolution",
"max_orig_resolution",
"skip_image_resolution",
]
# prepare map because option name may differ among argparse and user config
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
@@ -534,8 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
min_orig_resolution: {dataset.min_orig_resolution}
max_orig_resolution: {dataset.max_orig_resolution}
skip_image_resolution: {dataset.skip_image_resolution}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")

View File

@@ -687,8 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None,
min_orig_resolution: float = 0.0,
max_orig_resolution: float = float("inf"),
skip_image_resolution: float = 0.0,
) -> None:
super().__init__()
@@ -729,11 +728,7 @@ class BaseDataset(torch.utils.data.Dataset):
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
assert (
min_orig_resolution <= max_orig_resolution
), f"min_orig_resolution {min_orig_resolution} cannot be larger than max_orig_resolution {max_orig_resolution}"
self.min_orig_resolution = min_orig_resolution
self.max_orig_resolution = max_orig_resolution
self.skip_image_resolution = skip_image_resolution
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -782,11 +777,8 @@ class BaseDataset(torch.utils.data.Dataset):
def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
orig_resolution = math.sqrt(image_size[0] * image_size[1])
# min_orig_resolution is exclusive, max_orig_resolution is inclusive
return self.min_orig_resolution < orig_resolution <= self.max_orig_resolution
def has_orig_resolution_filter(self) -> bool:
return not (self.min_orig_resolution == 0.0 and self.max_orig_resolution == float("inf"))
# skip_image_resolution is an exclusive lower bound: images whose sqrt(width * height) is equal to or below it are skipped
return self.skip_image_resolution < orig_resolution
def update_dataset_image_counts(self):
for subset in self.subsets:
@@ -807,7 +799,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.num_reg_images = num_reg_images
def filter_registered_images_by_orig_resolution(self) -> int:
if not self.has_orig_resolution_filter():
if self.skip_image_resolution == 0:
return 0
filtered_count = 0
@@ -2022,16 +2014,14 @@ class DreamBoothDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
min_orig_resolution: Optional[float] = None,
max_orig_resolution: Optional[float] = None,
skip_image_resolution: Optional[float] = None,
) -> None:
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
skip_image_resolution,
)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2303,16 +2293,14 @@ class FineTuningDataset(BaseDataset):
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
min_orig_resolution: Optional[float] = None,
max_orig_resolution: Optional[float] = None,
skip_image_resolution: Optional[float] = None,
) -> None:
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
skip_image_resolution,
)
self.batch_size = batch_size
@@ -2499,16 +2487,14 @@ class ControlNetDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
min_orig_resolution: float = 0.0,
max_orig_resolution: float = float("inf"),
skip_image_resolution: float = 0.0,
) -> None:
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
skip_image_resolution,
)
db_subsets = []
@@ -2561,8 +2547,7 @@ class ControlNetDataset(BaseDataset):
validation_split,
validation_seed,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
skip_image_resolution,
)
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
@@ -2607,7 +2592,7 @@ class ControlNetDataset(BaseDataset):
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
if self.has_orig_resolution_filter():
if self.skip_image_resolution != 0:
if len(missing_imgs) > 0:
logger.warning(
f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
@@ -4749,18 +4734,11 @@ def add_dataset_arguments(
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--min_orig_resolution",
"--skip_image_resolution",
type=float,
default=0.0,
help="minimum original resolution for images (exclusive), defined by sqrt(width * height) before scaling"
" / 画像の元解像度の下限(排他的)。リサイズ前のsqrt(width * height)で判定します",
)
parser.add_argument(
"--max_orig_resolution",
type=float,
default=float("inf"),
help="maximum original resolution for images (inclusive), defined by sqrt(width * height) before scaling"
" / 画像の元解像度の上限包含的。リサイズ前のsqrt(width * height)で判定します",
help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling"
" / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します",
)
parser.add_argument(
"--bucket_reso_steps",

View File

@@ -1085,8 +1085,7 @@ class NetworkTrainer:
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"min_orig_resolution": dataset.min_orig_resolution,
"max_orig_resolution": dataset.max_orig_resolution,
"skip_image_resolution": dataset.skip_image_resolution,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
@@ -1193,8 +1192,7 @@ class NetworkTrainer:
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_min_orig_resolution": dataset.min_orig_resolution,
"ss_max_orig_resolution": dataset.max_orig_resolution,
"ss_skip_image_resolution": dataset.skip_image_resolution,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),