From 231df197ddf4372b3d90751146927f33e1965d1a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:26:30 +0900 Subject: [PATCH] Fix npz path for verification --- library/strategy_sdxl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index a4513336..3eb0ab6f 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -184,20 +184,20 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) + npz = np.load(npz_path) if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: return False except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True