From f4a004786500d80e1b47728d216aed9d76869a9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 20:50:44 +0900 Subject: [PATCH] feat: support metadata loading in MemoryEfficientSafeOpen --- library/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 07079c6d..4df8bd32 100644 --- a/library/utils.py +++ b/library/utils.py @@ -261,11 +261,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: class MemoryEfficientSafeOpen: - # does not support metadata loading def __init__(self, filename): self.filename = filename - self.header, self.header_size = self._read_header() self.file = open(filename, "rb") + self.header, self.header_size = self._read_header() def __enter__(self): return self @@ -276,6 +275,9 @@ class MemoryEfficientSafeOpen: def keys(self): return [k for k in self.header.keys() if k != "__metadata__"] + def metadata(self) -> Dict[str, str]: + return self.header.get("__metadata__", {}) + def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") @@ -293,10 +295,9 @@ class MemoryEfficientSafeOpen: return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): - with open(self.filename, "rb") as f: - header_size = struct.unpack("