diff --git a/library/utils.py b/library/utils.py index 07079c6d..4df8bd32 100644 --- a/library/utils.py +++ b/library/utils.py @@ -261,11 +261,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: class MemoryEfficientSafeOpen: - # does not support metadata loading def __init__(self, filename): self.filename = filename - self.header, self.header_size = self._read_header() self.file = open(filename, "rb") + self.header, self.header_size = self._read_header() def __enter__(self): return self @@ -276,6 +275,9 @@ class MemoryEfficientSafeOpen: def keys(self): return [k for k in self.header.keys() if k != "__metadata__"] + def metadata(self) -> Dict[str, str]: + return self.header.get("__metadata__", {}) + def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") @@ -293,10 +295,9 @@ class MemoryEfficientSafeOpen: return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): - with open(self.filename, "rb") as f: - header_size = struct.unpack("