mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 08:21:46 +00:00
Change skip_image_resolution to tuple
This commit is contained in:
@@ -108,7 +108,7 @@ class BaseDatasetParams:
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
resize_interpolation: Optional[str] = None
|
||||
skip_image_resolution: float = 0.0
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
@@ -245,7 +245,7 @@ class ConfigSanitizer:
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"resize_interpolation": str,
|
||||
"skip_image_resolution": Any(float, int),
|
||||
"skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
|
||||
@@ -687,7 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
skip_image_resolution: float = 0.0,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -776,9 +776,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return min_bucket_reso, max_bucket_reso
|
||||
|
||||
def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
|
||||
orig_resolution = math.sqrt(image_size[0] * image_size[1])
|
||||
# skip_image_resolution is exclusive
|
||||
return self.skip_image_resolution < orig_resolution
|
||||
return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1]
|
||||
|
||||
def update_dataset_image_counts(self):
|
||||
for subset in self.subsets:
|
||||
@@ -799,7 +798,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
def filter_registered_images_by_orig_resolution(self) -> int:
|
||||
if self.skip_image_resolution == 0:
|
||||
if self.skip_image_resolution is None:
|
||||
return 0
|
||||
|
||||
filtered_count = 0
|
||||
@@ -2014,7 +2013,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str],
|
||||
skip_image_resolution: Optional[float] = None,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
resolution,
|
||||
@@ -2293,7 +2292,7 @@ class FineTuningDataset(BaseDataset):
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
resize_interpolation: Optional[str],
|
||||
skip_image_resolution: Optional[float] = None,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
resolution,
|
||||
@@ -2487,7 +2486,7 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
skip_image_resolution: float = 0.0,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
resolution,
|
||||
@@ -2592,7 +2591,7 @@ class ControlNetDataset(BaseDataset):
|
||||
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
|
||||
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
|
||||
|
||||
if self.skip_image_resolution != 0:
|
||||
if self.skip_image_resolution is not None:
|
||||
if len(missing_imgs) > 0:
|
||||
logger.warning(
|
||||
f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
|
||||
@@ -4735,10 +4734,10 @@ def add_dataset_arguments(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_image_resolution",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="images not larger than this resolution will be skipped, defined by sqrt(width * height) before scaling"
|
||||
" / この解像度以下の画像はスキップされます。リサイズ前のsqrt(width * height)で判定します",
|
||||
type=str,
|
||||
default=None,
|
||||
help="images not larger than this resolution will be skipped ('size' or 'width,height')"
|
||||
" / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
@@ -5553,6 +5552,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
len(args.resolution) == 2
|
||||
), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
||||
|
||||
if args.skip_image_resolution is not None:
|
||||
args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
|
||||
if len(args.skip_image_resolution) == 1:
|
||||
args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
|
||||
assert (
|
||||
len(args.skip_image_resolution) == 2
|
||||
), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}"
|
||||
|
||||
if args.face_crop_aug_range is not None:
|
||||
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user