mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
WIP: update new latents caching
This commit is contained in:
@@ -360,11 +360,23 @@ class AugHelper:
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
@@ -373,10 +385,15 @@ class LatentsCachingStrategy:
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str):
|
||||
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
@@ -1034,7 +1051,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path)
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
|
||||
if not is_main_process: # prepare for multi-gpu, only store to info
|
||||
continue
|
||||
|
||||
@@ -1730,6 +1747,18 @@ class DreamBoothDataset(BaseDataset):
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
sizes = [None] * len(img_paths)
|
||||
|
||||
# new caching: get image size from cache files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
if strategy is not None:
|
||||
logger.info("get image size from cache files")
|
||||
size_set_count = 0
|
||||
for i, img_path in enumerate(tqdm(img_paths)):
|
||||
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
|
||||
if w is not None and h is not None:
|
||||
sizes[i] = [w, h]
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
@@ -2807,12 +2836,12 @@ def cache_batch_text_encoder_outputs_sd3(
|
||||
b_lg_out = b_lg_out.detach()
|
||||
b_t5_out = b_t5_out.detach()
|
||||
b_pool = b_pool.detach()
|
||||
|
||||
|
||||
for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
|
||||
# debug: NaN check
|
||||
if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
|
||||
raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
|
||||
|
||||
|
||||
if cache_to_disk:
|
||||
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user