feat: reduce memory usage and add memory efficient option for model saving

This commit is contained in:
Kohya S
2024-08-19 22:30:24 +09:00
parent 6e72a799c8
commit 486fe8f70a
4 changed files with 100 additions and 7 deletions

View File

@@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows: The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 19, 2024:
In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason.
An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code.
Aug 18, 2024: Aug 18, 2024:
Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.

View File

@@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser:
add_custom_train_arguments(parser) # TODO remove this from here add_custom_train_arguments(parser) # TODO remove this from here
flux_train_utils.add_flux_train_arguments(parser) flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument(
"--mem_eff_save",
action="store_true",
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
)
parser.add_argument( parser.add_argument(
"--fused_optimizer_groups", "--fused_optimizer_groups",
type=int, type=int,

View File

@@ -20,7 +20,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex() init_ipex()
from .utils import setup_logging from .utils import setup_logging, mem_eff_save_file
setup_logging() setup_logging()
import logging import logging
@@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
return model_pred, weighting return model_pred, weighting
def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): def save_models(
ckpt_path: str,
flux: flux_models.Flux,
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
state_dict = {} state_dict = {}
def update_sd(prefix, sd): def update_sd(prefix, sd):
for k, v in sd.items(): for k, v in sd.items():
key = prefix + k key = prefix + k
if save_dtype is not None: if save_dtype is not None and v.dtype != save_dtype:
v = v.detach().clone().to("cpu").to(save_dtype) v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v state_dict[key] = v
update_sd("", flux.state_dict()) update_sd("", flux.state_dict())
if not use_mem_eff_save:
save_file(state_dict, ckpt_path, metadata=sai_metadata) save_file(state_dict, ckpt_path, metadata=sai_metadata)
else:
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
def save_flux_model_on_train_end( def save_flux_model_on_train_end(
@@ -429,7 +438,7 @@ def save_flux_model_on_train_end(
): ):
def sd_saver(ckpt_file, epoch_no, global_step): def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype) save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise(
): ):
def sd_saver(ckpt_file, epoch_no, global_step): def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype) save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common( train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args, args,

View File

@@ -1,9 +1,12 @@
import logging import logging
import sys import sys
import threading import threading
from typing import *
import json
import struct
import torch import torch
from torchvision import transforms from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
@@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init) logger.info(msg_init)
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
# TODO make inf_utils.py # TODO make inf_utils.py