Mirror of https://github.com/kohya-ss/sd-scripts.git
Move filtering to __init__

This commit moves the skip_image_resolution (original-resolution) filtering out of BaseDataset.make_buckets and into dataset construction: the filter helper methods on BaseDataset and DreamBoothDataset are removed, and the DreamBooth, fine-tuning, and ControlNet datasets now drop images at or below the threshold area while building their image lists.
@@ -775,46 +775,6 @@ class BaseDataset(torch.utils.data.Dataset):
 
         return min_bucket_reso, max_bucket_reso
 
-    def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
-        # skip_image_resolution is exclusive
-        return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1]
-
-    def update_dataset_image_counts(self):
-        for subset in self.subsets:
-            subset.img_count = 0
-
-        num_train_images = 0
-        num_reg_images = 0
-        for image_key, image_info in self.image_data.items():
-            subset = self.image_to_subset[image_key]
-            subset.img_count += 1
-
-            if image_info.is_reg:
-                num_reg_images += image_info.num_repeats
-            else:
-                num_train_images += image_info.num_repeats
-
-        self.num_train_images = num_train_images
-        self.num_reg_images = num_reg_images
-
-    def filter_registered_images_by_orig_resolution(self) -> int:
-        if self.skip_image_resolution is None:
-            return 0
-
-        filtered_count = 0
-        for image_key, image_info in list(self.image_data.items()):
-            if self.check_orig_resolution(image_info.image_size):
-                continue
-
-            del self.image_data[image_key]
-            del self.image_to_subset[image_key]
-            filtered_count += 1
-
-        if filtered_count > 0:
-            self.update_dataset_image_counts()
-
-        return filtered_count
-
     def set_seed(self, seed):
         self.seed = seed
 
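Note: the removed check_orig_resolution compares pixel areas rather than per-axis dimensions, and the threshold is exclusive: an image is kept only if its area strictly exceeds the threshold area. The inline checks added later in this commit reproduce the same rule as size[0] * size[1] <= skip_image_area. A minimal standalone sketch of the rule, using a hypothetical name keeps_image that is not part of the codebase:

from typing import Tuple

def keeps_image(skip_resolution: Tuple[int, int], image_size: Tuple[int, int]) -> bool:
    # exclusive threshold: an image whose area merely equals the threshold is filtered out
    return skip_resolution[0] * skip_resolution[1] < image_size[0] * image_size[1]

print(keeps_image((512, 512), (512, 512)))  # False: equal area is filtered
print(keeps_image((512, 512), (768, 768)))  # True: strictly larger area is kept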
@@ -1037,10 +997,6 @@ class BaseDataset(torch.utils.data.Dataset):
             if info.image_size is None:
                 info.image_size = self.get_image_size(info.absolute_path)
 
-        filtered_count = self.filter_registered_images_by_orig_resolution()
-        if filtered_count > 0:
-            logger.info(f"filtered {filtered_count} images by original resolution")
-
         # # run in parallel
         # max_workers = min(os.cpu_count(), len(self.image_data))  # TODO consider multi-gpu (processes)
         # with ThreadPoolExecutor(max_workers) as executor:
@@ -1942,57 +1898,6 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
 
-    def register_regularization_images(
-        self, reg_infos: Sequence[Tuple[ImageInfo, DreamBoothSubset]], num_train_images: int
-    ) -> None:
-        if len(reg_infos) == 0 or num_train_images <= 0:
-            return
-
-        n = 0
-        first_loop = True
-        while n < num_train_images:
-            for info, subset in reg_infos:
-                if first_loop:
-                    self.register_image(info, subset)
-                    n += info.num_repeats
-                else:
-                    info.num_repeats += 1
-                    n += 1
-                if n >= num_train_images:
-                    break
-            first_loop = False
-
-    def rebalance_regularization_images(self):
-        if not self.is_training_dataset:
-            return
-
-        reg_infos = []
-        for image_key, image_info in list(self.image_data.items()):
-            if not image_info.is_reg:
-                continue
-
-            reg_infos.append((image_info, self.image_to_subset[image_key]))
-            del self.image_data[image_key]
-            del self.image_to_subset[image_key]
-
-        num_train_images = sum(info.num_repeats for info in self.image_data.values())
-        if len(reg_infos) == 0:
-            return
-
-        for info, subset in reg_infos:
-            info.num_repeats = subset.num_repeats
-
-        self.register_regularization_images(reg_infos, num_train_images)
-
-    def filter_registered_images_by_orig_resolution(self) -> int:
-        filtered_count = super().filter_registered_images_by_orig_resolution()
-
-        if filtered_count > 0 and self.is_training_dataset:
-            self.rebalance_regularization_images()
-            self.update_dataset_image_counts()
-
-        return filtered_count
-
     # The is_training_dataset defines the type of dataset, training or validation
     # if is_training_dataset is True -> training dataset
     # if is_training_dataset is False -> validation dataset
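The removed register_regularization_images performs round-robin balancing: on the first pass every regularization image is registered once with its own num_repeats; on later passes repeats are bumped one image at a time until they cover num_train_images. A simplified standalone sketch of that arithmetic, assuming the first pass does not already overshoot (balance_repeats is a hypothetical helper, not repo code):

def balance_repeats(reg_repeats, num_train_images):
    # first pass: each image registered once, contributing its own repeats
    n = sum(reg_repeats)
    # later passes: increment repeats one at a time, round-robin, until covered
    i = 0
    while n < num_train_images:
        reg_repeats[i % len(reg_repeats)] += 1
        n += 1
        i += 1
    return reg_repeats

print(balance_repeats([1, 1, 1], 10))  # [4, 3, 3]; repeats now sum to num_train_images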
@@ -2139,6 +2044,22 @@ class DreamBoothDataset(BaseDataset):
                         size_set_count += 1
                 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
 
+            if self.skip_image_resolution is not None:
+                filtered_img_paths = []
+                filtered_sizes = []
+                skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
+                for img_path, size in zip(img_paths, sizes):
+                    if size is None:  # no latents cache file, get image size by reading image file (slow)
+                        size = self.get_image_size(img_path)
+                    if size[0] * size[1] <= skip_image_area:
+                        continue
+                    filtered_img_paths.append(img_path)
+                    filtered_sizes.append(size)
+                if len(filtered_img_paths) < len(img_paths):
+                    logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
+                img_paths = filtered_img_paths
+                sizes = filtered_sizes
+
             # We want to create a training and validation split. This should be improved in the future
             # to allow a clearer distinction between training and validation. This can be seen as a
             # short-term solution to limit what is necessary to implement validation datasets
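The block added above filters two parallel lists, img_paths and sizes, in lockstep; sizes may contain None when no latents cache file exists, in which case the size is read from the image file only for entries that still need the area test. A standalone sketch of the pattern (filter_by_area and get_size are hypothetical stand-ins):

def filter_by_area(paths, sizes, get_size, skip_area):
    kept_paths, kept_sizes = [], []
    for path, size in zip(paths, sizes):
        if size is None:
            size = get_size(path)  # slow path: open the image to read its dimensions
        if size[0] * size[1] <= skip_area:
            continue  # at or below the threshold area: drop both entries
        kept_paths.append(path)
        kept_sizes.append(size)
    return kept_paths, kept_sizes

# e.g. with a 512x512 threshold, a cached 256x256 entry is dropped and an
# uncached (None) entry is measured before the test:
# paths, sizes = filter_by_area(["a.png", "b.png"], [(256, 256), None], read_size, 512 * 512)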
@@ -2271,7 +2192,20 @@ class DreamBoothDataset(BaseDataset):
         if num_reg_images == 0:
             logger.warning("no regularization images / 正則化画像が見つかりませんでした")
         else:
-            self.register_regularization_images(reg_infos, num_train_images)
+            # compute num_repeats: the counts are small anyway, so a simple loop is fine
+            n = 0
+            first_loop = True
+            while n < num_train_images:
+                for info, subset in reg_infos:
+                    if first_loop:
+                        self.register_image(info, subset)
+                        n += info.num_repeats
+                    else:
+                        info.num_repeats += 1  # rewrite registered info
+                        n += 1
+                    if n >= num_train_images:
+                        break
+                first_loop = False
 
         self.num_reg_images = num_reg_images
 
@@ -2396,6 +2330,7 @@ class FineTuningDataset(BaseDataset):
         tags_list = []
         size_set_from_metadata = 0
         size_set_from_cache_filename = 0
+        num_filtered = 0
         for image_key in image_keys_sorted_by_length_desc:
             img_md = metadata[image_key]
             caption = img_md.get("caption")
@@ -2454,6 +2389,16 @@ class FineTuningDataset(BaseDataset):
                 image_info.image_size = (w, h)
                 size_set_from_cache_filename += 1
 
+            if self.skip_image_resolution is not None:
+                size = image_info.image_size
+                if size is None:  # no image size in metadata or latents cache file, get image size by reading image file (slow)
+                    size = self.get_image_size(abs_path)
+                    image_info.image_size = size
+                skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
+                if size[0] * size[1] <= skip_image_area:
+                    num_filtered += 1
+                    continue
+
             self.register_image(image_info, subset)
 
         if size_set_from_cache_filename > 0:
@@ -2462,6 +2407,8 @@ class FineTuningDataset(BaseDataset):
             )
         if size_set_from_metadata > 0:
             logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
+        if num_filtered > 0:
+            logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
         self.num_train_images += len(metadata) * subset.num_repeats
 
         # TODO do not record tag freq when no tag
@@ -2591,25 +2538,13 @@ class ControlNetDataset(BaseDataset):
         conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths]  # normalize path
         extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
 
-        if self.skip_image_resolution is not None:
-            if len(missing_imgs) > 0:
-                logger.warning(
-                    f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
-                    + f" / 元画像解像度フィルタが有効なため、{len(missing_imgs)}枚の不足した制御用画像を無視します"
-                )
-            if len(extra_imgs) > 0:
-                logger.warning(
-                    f"ignore {len(extra_imgs)} extra conditioning images because original-resolution filtering is enabled"
-                    + f" / 元画像解像度フィルタが有効なため、{len(extra_imgs)}枚の余分な制御用画像を無視します"
-                )
-            # Later in `make_buckets` we assert `len(missing_imgs) == 0` but still ignore `extra_imgs`
-        else:
-            assert (
-                len(missing_imgs) == 0
-            ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
-            assert (
-                len(extra_imgs) == 0
-            ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
+        assert (
+            len(missing_imgs) == 0
+        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
+        if len(extra_imgs) > 0:
+            logger.warning(
+                f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
+            )
 
         self.conditioning_image_transforms = IMAGE_TRANSFORMS
 
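For context, missing_imgs and extra_imgs pair conditioning images with training images by file stem (basename with the extension stripped), as the removed make_buckets check in the next hunk also does. A hypothetical sketch of that pairing (split_pairs is not a function in the codebase):

import os

def split_pairs(train_paths, cond_paths):
    # pair by basename stem: "imgs/001.png" pairs with "cond/001.jpg"
    train_stems = {os.path.splitext(os.path.basename(p))[0] for p in train_paths}
    cond_stems = {os.path.splitext(os.path.basename(p))[0] for p in cond_paths}
    missing = sorted(train_stems - cond_stems)  # training images lacking a conditioning image
    extra = sorted(cond_stems - train_stems)    # conditioning images lacking a training image
    return missing, extra

print(split_pairs(["imgs/001.png", "imgs/002.png"], ["cond/001.png"]))  # (['002'], [])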
@@ -2619,18 +2554,8 @@ class ControlNetDataset(BaseDataset):
     def make_buckets(self):
         self.dreambooth_dataset_delegate.make_buckets()
 
-        missing_imgs = []
-        for info in self.dreambooth_dataset_delegate.image_data.values():
-            if info.cond_img_path is None:
-                missing_imgs.append(os.path.splitext(os.path.basename(info.absolute_path))[0])
-        assert (
-            len(missing_imgs) == 0
-        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
-
         self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
         self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
-        self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
-        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)