mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
feat: reduce memory usage and add memory efficient option for model saving
This commit is contained in:
@@ -20,7 +20,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from .utils import setup_logging
|
||||
from .utils import setup_logging, mem_eff_save_file
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
return model_pred, weighting
|
||||
|
||||
|
||||
def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None):
|
||||
def save_models(
|
||||
ckpt_path: str,
|
||||
flux: flux_models.Flux,
|
||||
sai_metadata: Optional[dict],
|
||||
save_dtype: Optional[torch.dtype] = None,
|
||||
use_mem_eff_save: bool = False,
|
||||
):
|
||||
state_dict = {}
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
if save_dtype is not None:
|
||||
if save_dtype is not None and v.dtype != save_dtype:
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
update_sd("", flux.state_dict())
|
||||
|
||||
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
||||
if not use_mem_eff_save:
|
||||
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
||||
else:
|
||||
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
||||
|
||||
|
||||
def save_flux_model_on_train_end(
|
||||
@@ -429,7 +438,7 @@ def save_flux_model_on_train_end(
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
||||
save_models(ckpt_file, flux, sai_metadata, save_dtype)
|
||||
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||
|
||||
@@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise(
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
||||
save_models(ckpt_file, flux, sai_metadata, save_dtype)
|
||||
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
|
||||
Reference in New Issue
Block a user