From 81411a398eb4ce28d84cc2da8238ff013d40d62f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 22:02:29 +0900 Subject: [PATCH] speed up getting image sizes --- library/strategy_base.py | 7 ++++++- library/strategy_flux.py | 9 +++------ library/strategy_sd.py | 12 ++++-------- library/strategy_sd3.py | 9 +++------ library/train_util.py | 23 ++++++++++++++++++++++- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97e..6a01c30a 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -204,9 +204,14 @@ class LatentsCachingStrategy: def batch_size(self): return self._batch_size - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + @property + def cache_suffix(self): raise NotImplementedError + def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + return int(w), int(h) + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: raise NotImplementedError diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8d..887113ca 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -189,12 +189,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + @property + def cache_suffix(self) -> str: + return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31..af472e49 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -108,14 +108,10 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): self.suffix = ( SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX ) - - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - # does not include old npz - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + + @property + def cache_suffix(self) -> str: + return self.suffix def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: # support old .npz diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index a2281890..9fde0208 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + @property + def cache_suffix(self) -> str: + return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( diff --git a/library/train_util.py b/library/train_util.py index 989758ad..dcc01f6f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1739,9 +1739,30 @@ class DreamBoothDataset(BaseDataset): strategy = LatentsCachingStrategy.get_strategy() if strategy is not None: logger.info("get image size from name of cache files") + + # make image path to npz path mapping + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths.sort() + npz_path_index = 0 + size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): - w, h = strategy.get_image_size_from_disk_cache_path(img_path) + l = len(os.path.splitext(img_path)[0]) # remove extension + found = False + while npz_path_index < len(npz_paths): # until found or end of npz_paths + # npz_paths are sorted, so if npz_path > img_path, img_path is not found + if npz_paths[npz_path_index][:l] > img_path[:l]: + break + if npz_paths[npz_path_index][:l] == img_path[:l]: # found + found = True + break + npz_path_index += 1 # next npz_path + + if found: + w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + else: + w, h = None, None + if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1