kohya-ss/sd-scripts (mirror of https://github.com/kohya-ss/sd-scripts.git)

Commit: feat: reduce memory usage and add memory efficient option for model saving
README.md
@@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
 The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
 
+Aug 19, 2024:
+In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason.
+
+An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code.
+
 Aug 18, 2024:
 Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
 
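Usage note (not part of the diff): following the two notes above, a typical run would keep both precisions aligned and opt into the experimental saver explicitly, e.g. `accelerate launch flux_train.py --mixed_precision bf16 --save_precision bf16 --mem_eff_save ...` (remaining options omitted; the `accelerate launch` wrapper and the `bf16` values are assumptions from the usual sd-scripts workflow, not shown in this diff).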
flux_train.py
@@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser:
     add_custom_train_arguments(parser)  # TODO remove this from here
     flux_train_utils.add_flux_train_arguments(parser)
 
+    parser.add_argument(
+        "--mem_eff_save",
+        action="store_true",
+        help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
+    )
+
     parser.add_argument(
         "--fused_optimizer_groups",
         type=int,
library/flux_train_utils.py
@@ -20,7 +20,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
-from .utils import setup_logging
+from .utils import setup_logging, mem_eff_save_file
 
 setup_logging()
 import logging
@@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
     return model_pred, weighting
 
 
-def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None):
+def save_models(
+    ckpt_path: str,
+    flux: flux_models.Flux,
+    sai_metadata: Optional[dict],
+    save_dtype: Optional[torch.dtype] = None,
+    use_mem_eff_save: bool = False,
+):
     state_dict = {}
 
     def update_sd(prefix, sd):
         for k, v in sd.items():
             key = prefix + k
-            if save_dtype is not None:
+            if save_dtype is not None and v.dtype != save_dtype:
                 v = v.detach().clone().to("cpu").to(save_dtype)
             state_dict[key] = v
 
     update_sd("", flux.state_dict())
 
-    save_file(state_dict, ckpt_path, metadata=sai_metadata)
+    if not use_mem_eff_save:
+        save_file(state_dict, ckpt_path, metadata=sai_metadata)
+    else:
+        mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
 
 
 def save_flux_model_on_train_end(
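Note (not part of the diff): the added `v.dtype != save_dtype` condition is what lowers peak memory when `--save_precision` matches the training precision, since tensors already in the target dtype are kept by reference instead of being duplicated through `detach().clone()`. A minimal sketch of the effect (tensor shape and dtype are made up):

```python
import torch

save_dtype = torch.bfloat16
v = torch.randn(1024, 1024, dtype=torch.bfloat16)  # already in the target dtype

# Old path: unconditionally materializes a second copy of every tensor.
copied = v.detach().clone().to("cpu").to(save_dtype)
assert copied.data_ptr() != v.data_ptr()  # extra buffer per tensor

# New path: copy only when a dtype conversion is actually required.
saved = v.detach().clone().to("cpu").to(save_dtype) if v.dtype != save_dtype else v
assert saved.data_ptr() == v.data_ptr()  # same storage, no extra allocation
```

The saver branch below it only switches to `mem_eff_save_file` when `--mem_eff_save` is passed, so the default `save_file` path is untouched.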
@@ -429,7 +438,7 @@ def save_flux_model_on_train_end(
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
-        save_models(ckpt_file, flux, sai_metadata, save_dtype)
+        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
 
     train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
 
@@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise(
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
-        save_models(ckpt_file, flux, sai_metadata, save_dtype)
+        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
 
     train_util.save_sd_model_on_epoch_end_or_stepwise_common(
         args,
library/utils.py
@@ -1,9 +1,12 @@
 import logging
 import sys
 import threading
+from typing import *
+import json
+import struct
 
 import torch
 from torchvision import transforms
-from typing import *
 from diffusers import EulerAncestralDiscreteScheduler
 import diffusers.schedulers.scheduling_euler_ancestral_discrete
 from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
@@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False):
     logger.info(msg_init)
 
 
+def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
+    """
+    memory efficient save file
+    """
+
+    _TYPES = {
+        torch.float64: "F64",
+        torch.float32: "F32",
+        torch.float16: "F16",
+        torch.bfloat16: "BF16",
+        torch.int64: "I64",
+        torch.int32: "I32",
+        torch.int16: "I16",
+        torch.int8: "I8",
+        torch.uint8: "U8",
+        torch.bool: "BOOL",
+        getattr(torch, "float8_e5m2", None): "F8_E5M2",
+        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
+    }
+    _ALIGN = 256
+
+    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
+        validated = {}
+        for key, value in metadata.items():
+            if not isinstance(key, str):
+                raise ValueError(f"Metadata key must be a string, got {type(key)}")
+            if not isinstance(value, str):
+                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
+                validated[key] = str(value)
+            else:
+                validated[key] = value
+        return validated
+
+    print(f"Using memory efficient save file: {filename}")
+
+    header = {}
+    offset = 0
+    if metadata:
+        header["__metadata__"] = validate_metadata(metadata)
+    for k, v in tensors.items():
+        if v.numel() == 0:  # empty tensor
+            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
+        else:
+            size = v.numel() * v.element_size()
+            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
+            offset += size
+
+    hjson = json.dumps(header).encode("utf-8")
+    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
+
+    with open(filename, "wb") as f:
+        f.write(struct.pack("<Q", len(hjson)))
+        f.write(hjson)
+
+        for k, v in tensors.items():
+            if v.numel() == 0:
+                continue
+            if v.is_cuda:
+                # Direct GPU to disk save
+                with torch.cuda.device(v.device):
+                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
+                        v = v.unsqueeze(0)
+                    tensor_bytes = v.contiguous().view(torch.uint8)
+                    tensor_bytes.cpu().numpy().tofile(f)
+            else:
+                # CPU tensor save
+                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
+                    v = v.unsqueeze(0)
+                v.contiguous().view(torch.uint8).numpy().tofile(f)
+
+
 # TODO make inf_utils.py
 
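For illustration (not part of the commit): `mem_eff_save_file` writes the safetensors layout by hand, an 8-byte little-endian header length, a JSON header space-padded to the 256-byte alignment, then each tensor's raw bytes in header order, streaming one tensor at a time instead of serializing the whole state dict into a single in-memory buffer first. A round-trip sketch under that assumption (file name and tensors are made up):

```python
import json
import struct

import torch
from safetensors.torch import load_file

from library.utils import mem_eff_save_file

tensors = {"weight": torch.randn(4, 4, dtype=torch.bfloat16), "bias": torch.zeros(4)}
mem_eff_save_file(tensors, "check.safetensors", metadata={"format": "pt"})

# The header is plain JSON behind an 8-byte little-endian length prefix.
with open("check.safetensors", "rb") as f:
    (hlen,) = struct.unpack("<Q", f.read(8))
    header = json.loads(f.read(hlen))
print(header["weight"])  # {'dtype': 'BF16', 'shape': [4, 4], 'data_offsets': [0, 32]}

# The standard safetensors loader should accept the file unchanged.
loaded = load_file("check.safetensors")
assert torch.equal(loaded["weight"], tensors["weight"])
```

Because the on-disk format is standard safetensors, files written this way stay loadable by the regular tooling; the memory saving is meant to come purely from writing tensors one at a time.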