mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
feat: support metadata loading in MemoryEfficientSafeOpen
This commit is contained in:
@@ -261,11 +261,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
|
||||
|
||||
class MemoryEfficientSafeOpen:
|
||||
# does not support metadata loading
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -276,6 +275,9 @@ class MemoryEfficientSafeOpen:
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
@@ -293,10 +295,9 @@ class MemoryEfficientSafeOpen:
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
||||
header_json = f.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
|
||||
Reference in New Issue
Block a user