mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution
This commit is contained in:
@@ -108,8 +108,7 @@ class BaseDatasetParams:
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
resize_interpolation: Optional[str] = None
|
||||
min_orig_resolution: float = 0.0
|
||||
max_orig_resolution: float = float("inf")
|
||||
skip_image_resolution: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
@@ -246,8 +245,7 @@ class ConfigSanitizer:
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"resize_interpolation": str,
|
||||
"min_orig_resolution": Any(float, int),
|
||||
"max_orig_resolution": Any(float, int),
|
||||
"skip_image_resolution": Any(float, int),
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
@@ -260,8 +258,7 @@ class ConfigSanitizer:
|
||||
ARGPARSE_NULLABLE_OPTNAMES = [
|
||||
"face_crop_aug_range",
|
||||
"resolution",
|
||||
"min_orig_resolution",
|
||||
"max_orig_resolution",
|
||||
"skip_image_resolution",
|
||||
]
|
||||
# prepare map because option name may differ among argparse and user config
|
||||
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
||||
@@ -534,8 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
min_orig_resolution: {dataset.min_orig_resolution}
|
||||
max_orig_resolution: {dataset.max_orig_resolution}
|
||||
skip_image_resolution: {dataset.skip_image_resolution}
|
||||
resize_interpolation: {dataset.resize_interpolation}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
""")
|
||||
|
||||
@@ -687,8 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
min_orig_resolution: float = 0.0,
|
||||
max_orig_resolution: float = float("inf"),
|
||||
skip_image_resolution: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -729,11 +728,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
assert (
|
||||
min_orig_resolution <= max_orig_resolution
|
||||
), f"min_orig_resolution {min_orig_resolution} cannot be larger than max_orig_resolution {max_orig_resolution}"
|
||||
self.min_orig_resolution = min_orig_resolution
|
||||
self.max_orig_resolution = max_orig_resolution
|
||||
self.skip_image_resolution = skip_image_resolution
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
@@ -782,11 +777,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
|
||||
orig_resolution = math.sqrt(image_size[0] * image_size[1])
|
||||
# min_orig_resolution is exclusive, max_orig_resolution is inclusive
|
||||
return self.min_orig_resolution < orig_resolution <= self.max_orig_resolution
|
||||
|
||||
def has_orig_resolution_filter(self) -> bool:
|
||||
return not (self.min_orig_resolution == 0.0 and self.max_orig_resolution == float("inf"))
|
||||
# skip_image_resolution is exclusive
|
||||
return self.skip_image_resolution < orig_resolution
|
||||
|
||||
def update_dataset_image_counts(self):
|
||||
for subset in self.subsets:
|
||||
@@ -807,7 +799,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
def filter_registered_images_by_orig_resolution(self) -> int:
|
||||
if not self.has_orig_resolution_filter():
|
||||
if self.skip_image_resolution == 0:
|
||||
return 0
|
||||
|
||||
filtered_count = 0
|
||||
@@ -2022,16 +2014,14 @@ class DreamBoothDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str],
|
||||
min_orig_resolution: Optional[float] = None,
|
||||
max_orig_resolution: Optional[float] = None,
|
||||
skip_image_resolution: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
min_orig_resolution,
|
||||
max_orig_resolution,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
@@ -2303,16 +2293,14 @@ class FineTuningDataset(BaseDataset):
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
resize_interpolation: Optional[str],
|
||||
min_orig_resolution: Optional[float] = None,
|
||||
max_orig_resolution: Optional[float] = None,
|
||||
skip_image_resolution: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
min_orig_resolution,
|
||||
max_orig_resolution,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
self.batch_size = batch_size
|
||||
@@ -2499,16 +2487,14 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
min_orig_resolution: float = 0.0,
|
||||
max_orig_resolution: float = float("inf"),
|
||||
skip_image_resolution: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
min_orig_resolution,
|
||||
max_orig_resolution,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
db_subsets = []
|
||||
@@ -2561,8 +2547,7 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split,
|
||||
validation_seed,
|
||||
resize_interpolation,
|
||||
min_orig_resolution,
|
||||
max_orig_resolution,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
@@ -2607,7 +2592,7 @@ class ControlNetDataset(BaseDataset):
|
||||
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
|
||||
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
|
||||
|
||||
if self.has_orig_resolution_filter():
|
||||
if self.skip_image_resolution != 0:
|
||||
if len(missing_imgs) > 0:
|
||||
logger.warning(
|
||||
f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
|
||||
@@ -4749,18 +4734,11 @@ def add_dataset_arguments(
|
||||
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_orig_resolution",
|
||||
"--skip_image_resolution",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="minimum original resolution for images (exclusive), defined by sqrt(width * height) before scaling"
|
||||
" / 画像の元解像度の下限(排他的)。リサイズ前のsqrt(width * height)で判定します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_orig_resolution",
|
||||
type=float,
|
||||
default=float("inf"),
|
||||
help="maximum original resolution for images (inclusive), defined by sqrt(width * height) before scaling"
|
||||
" / 画像の元解像度の上限(包含的)。リサイズ前のsqrt(width * height)で判定します",
|
||||
help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling"
|
||||
" / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
|
||||
@@ -1085,8 +1085,7 @@ class NetworkTrainer:
|
||||
"enable_bucket": bool(dataset.enable_bucket),
|
||||
"min_bucket_reso": dataset.min_bucket_reso,
|
||||
"max_bucket_reso": dataset.max_bucket_reso,
|
||||
"min_orig_resolution": dataset.min_orig_resolution,
|
||||
"max_orig_resolution": dataset.max_orig_resolution,
|
||||
"skip_image_resolution": dataset.skip_image_resolution,
|
||||
"tag_frequency": dataset.tag_frequency,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"resize_interpolation": dataset.resize_interpolation,
|
||||
@@ -1193,8 +1192,7 @@ class NetworkTrainer:
|
||||
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
||||
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
||||
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
||||
"ss_min_orig_resolution": dataset.min_orig_resolution,
|
||||
"ss_max_orig_resolution": dataset.max_orig_resolution,
|
||||
"ss_skip_image_resolution": dataset.skip_image_resolution,
|
||||
"ss_keep_tokens": args.keep_tokens,
|
||||
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
||||
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
||||
|
||||
Reference in New Issue
Block a user