feat: reduce memory usage and add memory efficient option for model saving

This commit is contained in:
Kohya S
2024-08-19 22:30:24 +09:00
parent 6e72a799c8
commit 486fe8f70a
4 changed files with 100 additions and 7 deletions

View File

@@ -1,9 +1,12 @@
import logging
import sys
import threading
from typing import *
import json
import struct
import torch
from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
@@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
# TODO make inf_utils.py