speed up getting image sizes

This commit is contained in:
Kohya S
2024-08-22 22:02:29 +09:00
parent 99744af53a
commit 81411a398e
5 changed files with 38 additions and 22 deletions

View File

@@ -204,9 +204,14 @@ class LatentsCachingStrategy:
def batch_size(self): def batch_size(self):
return self._batch_size return self._batch_size
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: @property
def cache_suffix(self):
raise NotImplementedError raise NotImplementedError
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError raise NotImplementedError

View File

@@ -189,12 +189,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: @property
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) def cache_suffix(self) -> str:
if len(npz_file) == 0: return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return ( return (

View File

@@ -108,14 +108,10 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
self.suffix = ( self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
) )
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: @property
# does not include old npz def cache_suffix(self) -> str:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) return self.suffix
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz # support old .npz

View File

@@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: @property
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) def cache_suffix(self) -> str:
if len(npz_file) == 0: return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return ( return (

View File

@@ -1739,9 +1739,30 @@ class DreamBoothDataset(BaseDataset):
strategy = LatentsCachingStrategy.get_strategy() strategy = LatentsCachingStrategy.get_strategy()
if strategy is not None: if strategy is not None:
logger.info("get image size from name of cache files") logger.info("get image size from name of cache files")
# make image path to npz path mapping
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
npz_paths.sort()
npz_path_index = 0
size_set_count = 0 size_set_count = 0
for i, img_path in enumerate(tqdm(img_paths)): for i, img_path in enumerate(tqdm(img_paths)):
w, h = strategy.get_image_size_from_disk_cache_path(img_path) l = len(os.path.splitext(img_path)[0]) # remove extension
found = False
while npz_path_index < len(npz_paths): # until found or end of npz_paths
# npz_paths are sorted, so if npz_path > img_path, img_path is not found
if npz_paths[npz_path_index][:l] > img_path[:l]:
break
if npz_paths[npz_path_index][:l] == img_path[:l]: # found
found = True
break
npz_path_index += 1 # next npz_path
if found:
w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index])
else:
w, h = None, None
if w is not None and h is not None: if w is not None and h is not None:
sizes[i] = [w, h] sizes[i] = [w, h]
size_set_count += 1 size_set_count += 1