Merge branch 'sd3' into multi-gpu-caching

This commit is contained in:
kohya-ss
2024-10-13 19:14:06 +09:00
5 changed files with 38 additions and 22 deletions

View File

@@ -405,9 +405,14 @@ class LatentsCachingStrategy:
def batch_size(self):
return self._batch_size
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
@property
def cache_suffix(self):
raise NotImplementedError
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError

View File

@@ -199,12 +199,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (

View File

@@ -144,14 +144,10 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
# does not include old npz
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
@property
def cache_suffix(self) -> str:
return self.suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz

View File

@@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (

View File

@@ -1857,9 +1857,30 @@ class DreamBoothDataset(BaseDataset):
strategy = LatentsCachingStrategy.get_strategy()
if strategy is not None:
logger.info("get image size from name of cache files")
# make image path to npz path mapping
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
npz_paths.sort()
npz_path_index = 0
size_set_count = 0
for i, img_path in enumerate(tqdm(img_paths)):
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
l = len(os.path.splitext(img_path)[0]) # remove extension
found = False
while npz_path_index < len(npz_paths): # until found or end of npz_paths
# npz_paths are sorted, so if npz_path > img_path, img_path is not found
if npz_paths[npz_path_index][:l] > img_path[:l]:
break
if npz_paths[npz_path_index][:l] == img_path[:l]: # found
found = True
break
npz_path_index += 1 # next npz_path
if found:
w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index])
else:
w, h = None, None
if w is not None and h is not None:
sizes[i] = [w, h]
size_set_count += 1