Add min_orig_resolution and max_orig_resolution

This commit is contained in:
woctordho
2026-02-20 12:16:43 +08:00
parent 48d368fa55
commit af3a55b21d
3 changed files with 198 additions and 25 deletions

View File

@@ -108,6 +108,8 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
min_orig_resolution: float = 0.0
max_orig_resolution: float = float("inf")
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -118,7 +120,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -244,6 +246,8 @@ class ConfigSanitizer:
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
"min_orig_resolution": Any(float, int),
"max_orig_resolution": Any(float, int),
}
# options handled by argparse but not handled by user config
@@ -256,6 +260,8 @@ class ConfigSanitizer:
ARGPARSE_NULLABLE_OPTNAMES = [
"face_crop_aug_range",
"resolution",
"min_orig_resolution",
"max_orig_resolution",
]
# prepare map because option name may differ among argparse and user config
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
@@ -528,6 +534,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
min_orig_resolution: {dataset.min_orig_resolution}
max_orig_resolution: {dataset.max_orig_resolution}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")

View File

@@ -687,6 +687,8 @@ class BaseDataset(torch.utils.data.Dataset):
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None,
min_orig_resolution: float = 0.0,
max_orig_resolution: float = float("inf"),
) -> None:
super().__init__()
@@ -727,6 +729,12 @@ class BaseDataset(torch.utils.data.Dataset):
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
assert (
min_orig_resolution <= max_orig_resolution
), f"min_orig_resolution {min_orig_resolution} cannot be larger than max_orig_resolution {max_orig_resolution}"
self.min_orig_resolution = min_orig_resolution
self.max_orig_resolution = max_orig_resolution
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -772,6 +780,50 @@ class BaseDataset(torch.utils.data.Dataset):
return min_bucket_reso, max_bucket_reso
def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
orig_resolution = math.sqrt(image_size[0] * image_size[1])
# min_orig_resolution is exclusive, max_orig_resolution is inclusive
return self.min_orig_resolution < orig_resolution <= self.max_orig_resolution
def has_orig_resolution_filter(self) -> bool:
return not (self.min_orig_resolution == 0.0 and self.max_orig_resolution == float("inf"))
def update_dataset_image_counts(self):
for subset in self.subsets:
subset.img_count = 0
num_train_images = 0
num_reg_images = 0
for image_key, image_info in self.image_data.items():
subset = self.image_to_subset[image_key]
subset.img_count += 1
if image_info.is_reg:
num_reg_images += image_info.num_repeats
else:
num_train_images += image_info.num_repeats
self.num_train_images = num_train_images
self.num_reg_images = num_reg_images
def filter_registered_images_by_orig_resolution(self) -> int:
if not self.has_orig_resolution_filter():
return 0
filtered_count = 0
for image_key, image_info in list(self.image_data.items()):
if self.check_orig_resolution(image_info.image_size):
continue
del self.image_data[image_key]
del self.image_to_subset[image_key]
filtered_count += 1
if filtered_count > 0:
self.update_dataset_image_counts()
return filtered_count
def set_seed(self, seed):
self.seed = seed
@@ -994,6 +1046,10 @@ class BaseDataset(torch.utils.data.Dataset):
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
filtered_count = self.filter_registered_images_by_orig_resolution()
if filtered_count > 0:
logger.info(f"filtered {filtered_count} images by original resolution")
# # run in parallel
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
# with ThreadPoolExecutor(max_workers) as executor:
@@ -1895,6 +1951,57 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
def register_regularization_images(
self, reg_infos: Sequence[Tuple[ImageInfo, DreamBoothSubset]], num_train_images: int
) -> None:
if len(reg_infos) == 0 or num_train_images <= 0:
return
n = 0
first_loop = True
while n < num_train_images:
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1
n += 1
if n >= num_train_images:
break
first_loop = False
def rebalance_regularization_images(self):
if not self.is_training_dataset:
return
reg_infos = []
for image_key, image_info in list(self.image_data.items()):
if not image_info.is_reg:
continue
reg_infos.append((image_info, self.image_to_subset[image_key]))
del self.image_data[image_key]
del self.image_to_subset[image_key]
num_train_images = sum(info.num_repeats for info in self.image_data.values())
if len(reg_infos) == 0:
return
for info, subset in reg_infos:
info.num_repeats = subset.num_repeats
self.register_regularization_images(reg_infos, num_train_images)
def filter_registered_images_by_orig_resolution(self) -> int:
filtered_count = super().filter_registered_images_by_orig_resolution()
if filtered_count > 0 and self.is_training_dataset:
self.rebalance_regularization_images()
self.update_dataset_image_counts()
return filtered_count
# The is_training_dataset defines the type of dataset, training or validation
# if is_training_dataset is True -> training dataset
# if is_training_dataset is False -> validation dataset
@@ -1915,8 +2022,17 @@ class DreamBoothDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
min_orig_resolution: Optional[float] = None,
max_orig_resolution: Optional[float] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2059,7 +2175,7 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
captions = [metas[img_path]["caption"] for img_path in img_paths]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
@@ -2166,20 +2282,7 @@ class DreamBoothDataset(BaseDataset):
if num_reg_images == 0:
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
else:
# num_repeatsを計算するどうせ大した数ではないのでループで処理する
n = 0
first_loop = True
while n < num_train_images:
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1 # rewrite registered info
n += 1
if n >= num_train_images:
break
first_loop = False
self.register_regularization_images(reg_infos, num_train_images)
self.num_reg_images = num_reg_images
@@ -2200,8 +2303,17 @@ class FineTuningDataset(BaseDataset):
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
min_orig_resolution: Optional[float] = None,
max_orig_resolution: Optional[float] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
)
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
@@ -2387,8 +2499,17 @@ class ControlNetDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
min_orig_resolution: float = 0.0,
max_orig_resolution: float = float("inf"),
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
)
db_subsets = []
for subset in subsets:
@@ -2440,6 +2561,8 @@ class ControlNetDataset(BaseDataset):
validation_split,
validation_seed,
resize_interpolation,
min_orig_resolution,
max_orig_resolution,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -2484,12 +2607,25 @@ class ControlNetDataset(BaseDataset):
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
if self.has_orig_resolution_filter():
if len(missing_imgs) > 0:
logger.warning(
f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
+ f" / 元画像解像度フィルタが有効なため、{len(missing_imgs)}枚の不足した制御用画像を無視します"
)
if len(extra_imgs) > 0:
logger.warning(
f"ignore {len(extra_imgs)} extra conditioning images because original-resolution filtering is enabled"
+ f" / 元画像解像度フィルタが有効なため、{len(extra_imgs)}枚の余分な制御用画像を無視します"
)
# Later in `make_buckets` we assert `len(missing_imgs) == 0` but still ignore `extra_imgs`
else:
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
self.conditioning_image_transforms = IMAGE_TRANSFORMS
@@ -2498,8 +2634,19 @@ class ControlNetDataset(BaseDataset):
def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets()
missing_imgs = []
for info in self.dreambooth_dataset_delegate.image_data.values():
if info.cond_img_path is None:
missing_imgs.append(os.path.splitext(os.path.basename(info.absolute_path))[0])
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
@@ -4601,6 +4748,20 @@ def add_dataset_arguments(
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--min_orig_resolution",
type=float,
default=0.0,
help="minimum original resolution for images (exclusive), defined by sqrt(width * height) before scaling"
" / 画像の元解像度の下限排他的。リサイズ前のsqrt(width * height)で判定します",
)
parser.add_argument(
"--max_orig_resolution",
type=float,
default=float("inf"),
help="maximum original resolution for images (inclusive), defined by sqrt(width * height) before scaling"
" / 画像の元解像度の上限包含的。リサイズ前のsqrt(width * height)で判定します",
)
parser.add_argument(
"--bucket_reso_steps",
type=int,

View File

@@ -1085,6 +1085,8 @@ class NetworkTrainer:
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"min_orig_resolution": dataset.min_orig_resolution,
"max_orig_resolution": dataset.max_orig_resolution,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
@@ -1191,6 +1193,8 @@ class NetworkTrainer:
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_min_orig_resolution": dataset.min_orig_resolution,
"ss_max_orig_resolution": dataset.max_orig_resolution,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),