mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: support metadata loading in MemoryEfficientSafeOpen
This commit is contained in:
@@ -261,11 +261,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
|||||||
|
|
||||||
|
|
||||||
class MemoryEfficientSafeOpen:
|
class MemoryEfficientSafeOpen:
|
||||||
# does not support metadata loading
|
|
||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.header, self.header_size = self._read_header()
|
|
||||||
self.file = open(filename, "rb")
|
self.file = open(filename, "rb")
|
||||||
|
self.header, self.header_size = self._read_header()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@@ -276,6 +275,9 @@ class MemoryEfficientSafeOpen:
|
|||||||
def keys(self):
|
def keys(self):
|
||||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||||
|
|
||||||
|
def metadata(self) -> Dict[str, str]:
|
||||||
|
return self.header.get("__metadata__", {})
|
||||||
|
|
||||||
def get_tensor(self, key):
|
def get_tensor(self, key):
|
||||||
if key not in self.header:
|
if key not in self.header:
|
||||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||||
@@ -293,10 +295,9 @@ class MemoryEfficientSafeOpen:
|
|||||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||||
|
|
||||||
def _read_header(self):
|
def _read_header(self):
|
||||||
with open(self.filename, "rb") as f:
|
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
header_json = self.file.read(header_size).decode("utf-8")
|
||||||
header_json = f.read(header_size).decode("utf-8")
|
return json.loads(header_json), header_size
|
||||||
return json.loads(header_json), header_size
|
|
||||||
|
|
||||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||||
|
|||||||
Reference in New Issue
Block a user