feat: reduce memory usage and add memory efficient option for model saving

This commit is contained in:
Kohya S
2024-08-19 22:30:24 +09:00
parent 6e72a799c8
commit 486fe8f70a
4 changed files with 100 additions and 7 deletions

View File

@@ -20,7 +20,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from .utils import setup_logging
from .utils import setup_logging, mem_eff_save_file
setup_logging()
import logging
@@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
return model_pred, weighting
def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None):
def save_models(
ckpt_path: str,
flux: flux_models.Flux,
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None:
if save_dtype is not None and v.dtype != save_dtype:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("", flux.state_dict())
save_file(state_dict, ckpt_path, metadata=sai_metadata)
if not use_mem_eff_save:
save_file(state_dict, ckpt_path, metadata=sai_metadata)
else:
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
def save_flux_model_on_train_end(
@@ -429,7 +438,7 @@ def save_flux_model_on_train_end(
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype)
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise(
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype)
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,