mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge branch 'sd3' into multi-gpu-caching
This commit is contained in:
@@ -405,9 +405,14 @@ class LatentsCachingStrategy:
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
@property
|
||||
def cache_suffix(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -199,12 +199,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX)
|
||||
if len(npz_file) == 0:
|
||||
return None, None
|
||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
|
||||
@@ -144,14 +144,10 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
self.suffix = (
|
||||
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
# does not include old npz
|
||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix)
|
||||
if len(npz_file) == 0:
|
||||
return None, None
|
||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return self.suffix
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
# support old .npz
|
||||
|
||||
@@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
|
||||
if len(npz_file) == 0:
|
||||
return None, None
|
||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
|
||||
@@ -1857,9 +1857,30 @@ class DreamBoothDataset(BaseDataset):
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
if strategy is not None:
|
||||
logger.info("get image size from name of cache files")
|
||||
|
||||
# make image path to npz path mapping
|
||||
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
|
||||
npz_paths.sort()
|
||||
npz_path_index = 0
|
||||
|
||||
size_set_count = 0
|
||||
for i, img_path in enumerate(tqdm(img_paths)):
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
|
||||
l = len(os.path.splitext(img_path)[0]) # remove extension
|
||||
found = False
|
||||
while npz_path_index < len(npz_paths): # until found or end of npz_paths
|
||||
# npz_paths are sorted, so if npz_path > img_path, img_path is not found
|
||||
if npz_paths[npz_path_index][:l] > img_path[:l]:
|
||||
break
|
||||
if npz_paths[npz_path_index][:l] == img_path[:l]: # found
|
||||
found = True
|
||||
break
|
||||
npz_path_index += 1 # next npz_path
|
||||
|
||||
if found:
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index])
|
||||
else:
|
||||
w, h = None, None
|
||||
|
||||
if w is not None and h is not None:
|
||||
sizes[i] = [w, h]
|
||||
size_set_count += 1
|
||||
|
||||
Reference in New Issue
Block a user