diff --git a/library/train_util.py b/library/train_util.py
index 5b9d0615..c96e0b85 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -775,46 +775,6 @@ class BaseDataset(torch.utils.data.Dataset):
         return min_bucket_reso, max_bucket_reso
 
-    def check_orig_resolution(self, image_size: Tuple[int, int]) -> bool:
-        # skip_image_resolution is exclusive
-        return self.skip_image_resolution[0] * self.skip_image_resolution[1] < image_size[0] * image_size[1]
-
-    def update_dataset_image_counts(self):
-        for subset in self.subsets:
-            subset.img_count = 0
-
-        num_train_images = 0
-        num_reg_images = 0
-        for image_key, image_info in self.image_data.items():
-            subset = self.image_to_subset[image_key]
-            subset.img_count += 1
-
-            if image_info.is_reg:
-                num_reg_images += image_info.num_repeats
-            else:
-                num_train_images += image_info.num_repeats
-
-        self.num_train_images = num_train_images
-        self.num_reg_images = num_reg_images
-
-    def filter_registered_images_by_orig_resolution(self) -> int:
-        if self.skip_image_resolution is None:
-            return 0
-
-        filtered_count = 0
-        for image_key, image_info in list(self.image_data.items()):
-            if self.check_orig_resolution(image_info.image_size):
-                continue
-
-            del self.image_data[image_key]
-            del self.image_to_subset[image_key]
-            filtered_count += 1
-
-        if filtered_count > 0:
-            self.update_dataset_image_counts()
-
-        return filtered_count
-
     def set_seed(self, seed):
         self.seed = seed
 
@@ -1037,10 +997,6 @@ class BaseDataset(torch.utils.data.Dataset):
             if info.image_size is None:
                 info.image_size = self.get_image_size(info.absolute_path)
 
-        filtered_count = self.filter_registered_images_by_orig_resolution()
-        if filtered_count > 0:
-            logger.info(f"filtered {filtered_count} images by original resolution")
-
         # # run in parallel
         # max_workers = min(os.cpu_count(), len(self.image_data))  # TODO consider multi-gpu (processes)
         # with ThreadPoolExecutor(max_workers) as executor:
@@ -1942,57 +1898,6 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
 
-    def register_regularization_images(
-        self, reg_infos: Sequence[Tuple[ImageInfo, DreamBoothSubset]], num_train_images: int
-    ) -> None:
-        if len(reg_infos) == 0 or num_train_images <= 0:
-            return
-
-        n = 0
-        first_loop = True
-        while n < num_train_images:
-            for info, subset in reg_infos:
-                if first_loop:
-                    self.register_image(info, subset)
-                    n += info.num_repeats
-                else:
-                    info.num_repeats += 1
-                    n += 1
-                if n >= num_train_images:
-                    break
-            first_loop = False
-
-    def rebalance_regularization_images(self):
-        if not self.is_training_dataset:
-            return
-
-        reg_infos = []
-        for image_key, image_info in list(self.image_data.items()):
-            if not image_info.is_reg:
-                continue
-
-            reg_infos.append((image_info, self.image_to_subset[image_key]))
-            del self.image_data[image_key]
-            del self.image_to_subset[image_key]
-
-        num_train_images = sum(info.num_repeats for info in self.image_data.values())
-        if len(reg_infos) == 0:
-            return
-
-        for info, subset in reg_infos:
-            info.num_repeats = subset.num_repeats
-
-        self.register_regularization_images(reg_infos, num_train_images)
-
-    def filter_registered_images_by_orig_resolution(self) -> int:
-        filtered_count = super().filter_registered_images_by_orig_resolution()
-
-        if filtered_count > 0 and self.is_training_dataset:
-            self.rebalance_regularization_images()
-            self.update_dataset_image_counts()
-
-        return filtered_count
-
     # The is_training_dataset defines the type of dataset, training or validation
     # if is_training_dataset is True -> training dataset
     # if is_training_dataset is False -> validation dataset
@@ -2139,6 +2044,22 @@ class DreamBoothDataset(BaseDataset):
                             size_set_count += 1
                     logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
 
+            if self.skip_image_resolution is not None:
+                filtered_img_paths = []
+                filtered_sizes = []
+                skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
+                for img_path, size in zip(img_paths, sizes):
+                    if size is None:  # no latents cache file, get image size by reading image file (slow)
+                        size = self.get_image_size(img_path)
+                    if size[0] * size[1] <= skip_image_area:
+                        continue
+                    filtered_img_paths.append(img_path)
+                    filtered_sizes.append(size)
+                if len(filtered_img_paths) < len(img_paths):
+                    logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
+                img_paths = filtered_img_paths
+                sizes = filtered_sizes
+
             # We want to create a training and validation split. This should be improved in the future
             # to allow a clearer distinction between training and validation. This can be seen as a
             # short-term solution to limit what is necessary to implement validation datasets
@@ -2271,7 +2192,20 @@ class DreamBoothDataset(BaseDataset):
         if num_reg_images == 0:
             logger.warning("no regularization images / 正則化画像が見つかりませんでした")
         else:
-            self.register_regularization_images(reg_infos, num_train_images)
+            # calculate num_repeats: the count is small anyway, so handle it with a simple loop
+            n = 0
+            first_loop = True
+            while n < num_train_images:
+                for info, subset in reg_infos:
+                    if first_loop:
+                        self.register_image(info, subset)
+                        n += info.num_repeats
+                    else:
+                        info.num_repeats += 1  # rewrite registered info
+                        n += 1
+                    if n >= num_train_images:
+                        break
+                first_loop = False
 
         self.num_reg_images = num_reg_images
 
@@ -2396,6 +2330,7 @@ class FineTuningDataset(BaseDataset):
         tags_list = []
         size_set_from_metadata = 0
         size_set_from_cache_filename = 0
+        num_filtered = 0
         for image_key in image_keys_sorted_by_length_desc:
             img_md = metadata[image_key]
             caption = img_md.get("caption")
@@ -2454,6 +2389,16 @@ class FineTuningDataset(BaseDataset):
                     image_info.image_size = (w, h)
                     size_set_from_cache_filename += 1
 
+            if self.skip_image_resolution is not None:
+                size = image_info.image_size
+                if size is None:  # no image size in metadata or latents cache file, get image size by reading image file (slow)
+                    size = self.get_image_size(abs_path)
+                    image_info.image_size = size
+                skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
+                if size[0] * size[1] <= skip_image_area:
+                    num_filtered += 1
+                    continue
+
             self.register_image(image_info, subset)
 
         if size_set_from_cache_filename > 0:
             logger.info(
@@ -2462,6 +2407,8 @@ class FineTuningDataset(BaseDataset):
             )
         if size_set_from_metadata > 0:
             logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
+        if num_filtered > 0:
+            logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
 
         self.num_train_images += len(metadata) * subset.num_repeats
 
         # TODO do not record tag freq when no tag
@@ -2591,25 +2538,13 @@ class ControlNetDataset(BaseDataset):
             conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths]  # normalize path
             extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
 
-        if self.skip_image_resolution is not None:
-            if len(missing_imgs) > 0:
-                logger.warning(
-                    f"ignore {len(missing_imgs)} missing conditioning images because original-resolution filtering is enabled"
-                    + f" / 元画像解像度フィルタが有効なため、{len(missing_imgs)}枚の不足した制御用画像を無視します"
-                )
-            if len(extra_imgs) > 0:
-                logger.warning(
-                    f"ignore {len(extra_imgs)} extra conditioning images because original-resolution filtering is enabled"
-                    + f" / 元画像解像度フィルタが有効なため、{len(extra_imgs)}枚の余分な制御用画像を無視します"
-                )
-            # Later in `make_buckets` we assert `len(missing_imgs) == 0` but still ignore `extra_imgs`
-        else:
-            assert (
-                len(missing_imgs) == 0
-            ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
-            assert (
-                len(extra_imgs) == 0
-            ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
+        assert (
+            len(missing_imgs) == 0
+        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
+        if len(extra_imgs) > 0:
+            logger.warning(
+                f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
+            )
 
         self.conditioning_image_transforms = IMAGE_TRANSFORMS
 
@@ -2619,18 +2554,8 @@ class ControlNetDataset(BaseDataset):
     def make_buckets(self):
         self.dreambooth_dataset_delegate.make_buckets()
 
-        missing_imgs = []
-        for info in self.dreambooth_dataset_delegate.image_data.values():
-            if info.cond_img_path is None:
-                missing_imgs.append(os.path.splitext(os.path.basename(info.absolute_path))[0])
-        assert (
-            len(missing_imgs) == 0
-        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
-
         self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
         self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
-        self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
-        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
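
Note on the inlined resolution filter: both dataset classes above treat skip_image_resolution as an exclusive pixel-area threshold, so an image survives only if width * height is strictly greater than skip_w * skip_h, independent of aspect ratio. A minimal standalone sketch of that predicate (the helper name keep_image is illustrative, not part of the patch):

    from typing import Optional, Tuple

    def keep_image(size: Tuple[int, int], skip_image_resolution: Optional[Tuple[int, int]]) -> bool:
        # Filtering is disabled when skip_image_resolution is None; otherwise the
        # threshold is exclusive, so an image whose area exactly equals
        # skip_w * skip_h is dropped, matching the `<=` checks in the diff.
        if skip_image_resolution is None:
            return True
        skip_image_area = skip_image_resolution[0] * skip_image_resolution[1]
        return size[0] * size[1] > skip_image_area

    assert not keep_image((512, 512), (512, 512))  # equal area is filtered out
    assert keep_image((640, 480), (512, 512))      # 307200 > 262144: kept despite the short side
    assert keep_image((256, 256), None)            # no threshold, nothing is filtered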
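Note on the restored regularization loop: the DreamBoothDataset hunk registers each regularization image once with its subset's num_repeats, then keeps cycling over the images, adding one repeat at a time, until the total covers num_train_images. A simplified sketch of the same round-robin balancing over plain repeat counts (the function name and list-of-ints representation are illustrative only):

    from typing import List

    def balance_reg_repeats(initial_repeats: List[int], num_train_images: int) -> List[int]:
        # First pass: every image contributes its initial num_repeats.
        # Later passes: bump one image at a time, round-robin, stopping as soon
        # as the running total reaches num_train_images (same shape as the
        # while/for/first_loop structure in the diff).
        repeats = list(initial_repeats)
        n = sum(repeats)
        i = 0
        while n < num_train_images:
            repeats[i % len(repeats)] += 1
            n += 1
            i += 1
        return repeats

    # Three regularization images with num_repeats=1 each and 10 training images:
    # the repeats become [4, 3, 3], so regularization steps cover all 10.
    print(balance_reg_repeats([1, 1, 1], 10))  # -> [4, 3, 3]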