mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Fix npz path for verification
This commit is contained in:
@@ -184,20 +184,20 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
|
||||
def is_disk_cached_outputs_expected(self, abs_path: str):
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(self.get_outputs_npz_path(abs_path))
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user