Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 06:28:48 +00:00
Add skip_image_resolution to deduplicate multi-resolution dataset (#2273)
* Add min_orig_resolution and max_orig_resolution
* Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution
* Change skip_image_resolution to tuple
* Move filtering to __init__
* Minor fix
This commit is contained in:
@@ -108,6 +108,7 @@ class BaseDatasetParams:
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
resize_interpolation: Optional[str] = None
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
@@ -244,6 +245,7 @@ class ConfigSanitizer:
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"resize_interpolation": str,
|
||||
"skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
@@ -256,6 +258,7 @@ class ConfigSanitizer:
|
||||
ARGPARSE_NULLABLE_OPTNAMES = [
|
||||
"face_crop_aug_range",
|
||||
"resolution",
|
||||
"skip_image_resolution",
|
||||
]
|
||||
# prepare map because option name may differ among argparse and user config
|
||||
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
||||
@@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
skip_image_resolution: {dataset.skip_image_resolution}
|
||||
resize_interpolation: {dataset.resize_interpolation}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
""")
|
||||
|
||||
@@ -687,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -727,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
self.skip_image_resolution = skip_image_resolution
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
|
||||
@@ -1915,8 +1918,15 @@ class DreamBoothDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str],
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -2034,6 +2044,22 @@ class DreamBoothDataset(BaseDataset):
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
if self.skip_image_resolution is not None:
|
||||
filtered_img_paths = []
|
||||
filtered_sizes = []
|
||||
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
|
||||
for img_path, size in zip(img_paths, sizes):
|
||||
if size is None: # no latents cache file, get image size by reading image file (slow)
|
||||
size = self.get_image_size(img_path)
|
||||
if size[0] * size[1] <= skip_image_area:
|
||||
continue
|
||||
filtered_img_paths.append(img_path)
|
||||
filtered_sizes.append(size)
|
||||
if len(filtered_img_paths) < len(img_paths):
|
||||
logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
|
||||
img_paths = filtered_img_paths
|
||||
sizes = filtered_sizes
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
@@ -2059,7 +2085,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
captions = [meta["caption"] for meta in metas.values()]
|
||||
captions = [metas[img_path]["caption"] for img_path in img_paths]
|
||||
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
|
||||
else:
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
@@ -2200,8 +2226,15 @@ class FineTuningDataset(BaseDataset):
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
resize_interpolation: Optional[str],
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
@@ -2297,6 +2330,7 @@ class FineTuningDataset(BaseDataset):
|
||||
tags_list = []
|
||||
size_set_from_metadata = 0
|
||||
size_set_from_cache_filename = 0
|
||||
num_filtered = 0
|
||||
for image_key in image_keys_sorted_by_length_desc:
|
||||
img_md = metadata[image_key]
|
||||
caption = img_md.get("caption")
|
||||
@@ -2355,6 +2389,16 @@ class FineTuningDataset(BaseDataset):
|
||||
image_info.image_size = (w, h)
|
||||
size_set_from_cache_filename += 1
|
||||
|
||||
if self.skip_image_resolution is not None:
|
||||
size = image_info.image_size
|
||||
if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow)
|
||||
size = self.get_image_size(abs_path)
|
||||
image_info.image_size = size
|
||||
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
|
||||
if size[0] * size[1] <= skip_image_area:
|
||||
num_filtered += 1
|
||||
continue
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
if size_set_from_cache_filename > 0:
|
||||
@@ -2363,6 +2407,8 @@ class FineTuningDataset(BaseDataset):
|
||||
)
|
||||
if size_set_from_metadata > 0:
|
||||
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
|
||||
if num_filtered > 0:
|
||||
logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
|
||||
self.num_train_images += len(metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
@@ -2387,8 +2433,15 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
skip_image_resolution: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
super().__init__(
|
||||
resolution,
|
||||
network_multiplier,
|
||||
debug_dataset,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2440,6 +2493,7 @@ class ControlNetDataset(BaseDataset):
|
||||
validation_split,
|
||||
validation_seed,
|
||||
resize_interpolation,
|
||||
skip_image_resolution,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
@@ -2487,9 +2541,10 @@ class ControlNetDataset(BaseDataset):
|
||||
assert (
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
assert (
|
||||
len(extra_imgs) == 0
|
||||
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
if len(extra_imgs) > 0:
|
||||
logger.warning(
|
||||
f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
)
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -4601,6 +4656,13 @@ def add_dataset_arguments(
|
||||
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
|
||||
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_image_resolution",
|
||||
type=str,
|
||||
default=None,
|
||||
help="images not larger than this resolution will be skipped ('size' or 'width,height')"
|
||||
" / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
@@ -5414,6 +5476,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
len(args.resolution) == 2
|
||||
), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
||||
|
||||
if args.skip_image_resolution is not None:
|
||||
args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
|
||||
if len(args.skip_image_resolution) == 1:
|
||||
args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
|
||||
assert (
|
||||
len(args.skip_image_resolution) == 2
|
||||
), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}"
|
||||
|
||||
if args.face_crop_aug_range is not None:
|
||||
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
|
||||
assert (
|
||||
|
||||
@@ -1085,6 +1085,7 @@ class NetworkTrainer:
|
||||
"enable_bucket": bool(dataset.enable_bucket),
|
||||
"min_bucket_reso": dataset.min_bucket_reso,
|
||||
"max_bucket_reso": dataset.max_bucket_reso,
|
||||
"skip_image_resolution": dataset.skip_image_resolution,
|
||||
"tag_frequency": dataset.tag_frequency,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"resize_interpolation": dataset.resize_interpolation,
|
||||
@@ -1191,6 +1192,7 @@ class NetworkTrainer:
|
||||
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
||||
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
||||
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
||||
"ss_skip_image_resolution": dataset.skip_image_resolution,
|
||||
"ss_keep_tokens": args.keep_tokens,
|
||||
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
||||
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
||||
|
||||
Reference in New Issue
Block a user