Fix npz path for verification

This commit is contained in:
Kohya S
2024-08-05 20:26:30 +09:00
parent 002d75179a
commit 231df197dd

View File

@@ -184,20 +184,20 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, abs_path: str):
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(self.get_outputs_npz_path(abs_path))
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
logger.error(f"Error loading file: {npz_path}")
raise e
return True