mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add save_every_n_steps option
This commit is contained in:
@@ -74,6 +74,11 @@ LAST_STATE_NAME = "{}-state"
|
||||
DEFAULT_EPOCH_NAME = "epoch"
|
||||
DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
|
||||
DEFAULT_STEP_NAME = "at"
|
||||
STEP_STATE_NAME = "{}-step{:08d}-state"
|
||||
STEP_FILE_NAME = "{}-step{:08d}"
|
||||
STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
|
||||
|
||||
# region dataset
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
||||
@@ -1986,18 +1991,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_n_epoch_ratio",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)",
|
||||
)
|
||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||
parser.add_argument(
|
||||
"--save_last_n_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_last_n_epochs_state",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)",
|
||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_last_n_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_last_n_steps_state",
|
||||
type=int,
|
||||
default=None,
|
||||
help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_state",
|
||||
@@ -2903,26 +2928,53 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
||||
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
|
||||
return model_name, ckpt_name
|
||||
def default_if_none(value, default):
|
||||
return default if value is None else value
|
||||
|
||||
|
||||
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||
if saving:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
save_func()
|
||||
|
||||
if args.save_last_n_epochs is not None:
|
||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
||||
remove_old_func(remove_epoch_no)
|
||||
return saving
|
||||
def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext
|
||||
|
||||
|
||||
def save_sd_model_on_epoch_end(
|
||||
def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
return STEP_FILE_NAME.format(model_name, step_no) + ext
|
||||
|
||||
|
||||
def get_last_ckpt_name(args: argparse.Namespace, ext: str):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
return model_name + ext
|
||||
|
||||
|
||||
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
|
||||
if args.save_last_n_epochs is None:
|
||||
return None
|
||||
|
||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
||||
if remove_epoch_no < 0:
|
||||
return None
|
||||
return remove_epoch_no
|
||||
|
||||
|
||||
def get_remove_step_no(args: argparse.Namespace, step_no: int):
|
||||
if args.save_last_n_steps is None:
|
||||
return None
|
||||
|
||||
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
|
||||
# save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する
|
||||
remove_step_no = step_no - args.save_last_n_steps - 1
|
||||
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
|
||||
if remove_step_no < 0:
|
||||
return None
|
||||
return remove_step_no
|
||||
|
||||
|
||||
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
||||
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
||||
def save_sd_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
src_path: str,
|
||||
save_stable_diffusion_format: bool,
|
||||
@@ -2935,57 +2987,87 @@ def save_sd_model_on_epoch_end(
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
epoch_no = epoch + 1
|
||||
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
|
||||
if on_epoch_end:
|
||||
epoch_no = epoch + 1
|
||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||
if not saving:
|
||||
return
|
||||
|
||||
if save_stable_diffusion_format:
|
||||
|
||||
def save_sd():
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||
|
||||
def remove_sd(old_epoch_no):
|
||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
save_func = save_sd
|
||||
remove_old_func = remove_sd
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
remove_no = get_remove_epoch_no(args, epoch_no)
|
||||
else:
|
||||
# 保存するか否かは呼び出し側で判断済み
|
||||
|
||||
def save_du():
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
epoch_no = epoch # 例: 最初のepochの途中で保存したら0になる、SDモデルに保存される
|
||||
remove_no = get_remove_step_no(args, global_step)
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if save_stable_diffusion_format:
|
||||
ext = ".safetensors" if use_safetensors else ".ckpt"
|
||||
|
||||
if on_epoch_end:
|
||||
ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no)
|
||||
else:
|
||||
ckpt_name = get_step_ckpt_name(args, ext, global_step)
|
||||
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||
|
||||
# remove older checkpoints
|
||||
if remove_no is not None:
|
||||
if on_epoch_end:
|
||||
remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no)
|
||||
else:
|
||||
remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no)
|
||||
|
||||
remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name)
|
||||
if os.path.exists(remove_ckpt_file):
|
||||
print(f"removing old checkpoint: {remove_ckpt_file}")
|
||||
os.remove(remove_ckpt_file)
|
||||
|
||||
else:
|
||||
if on_epoch_end:
|
||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
||||
print(f"saving model: {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
|
||||
|
||||
def remove_du(old_epoch_no):
|
||||
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
||||
if os.path.exists(out_dir_old):
|
||||
print(f"removing old model: {out_dir_old}")
|
||||
shutil.rmtree(out_dir_old)
|
||||
print(f"saving model: {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||
|
||||
save_func = save_du
|
||||
remove_old_func = remove_du
|
||||
# remove older checkpoints
|
||||
if remove_no is not None:
|
||||
if on_epoch_end:
|
||||
remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
|
||||
else:
|
||||
remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
|
||||
|
||||
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
||||
if os.path.exists(remove_out_dir):
|
||||
print(f"removing old model: {remove_out_dir}")
|
||||
shutil.rmtree(remove_out_dir)
|
||||
|
||||
if on_epoch_end:
|
||||
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
|
||||
else:
|
||||
save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
|
||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
||||
print("saving state.")
|
||||
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
|
||||
print(f"saving state at epoch {epoch_no}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||
accelerator.save_state(state_dir)
|
||||
if args.save_state_to_huggingface:
|
||||
@@ -3001,12 +3083,40 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
|
||||
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
|
||||
print(f"saving state at step {step_no}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
|
||||
accelerator.save_state(state_dir)
|
||||
if args.save_state_to_huggingface:
|
||||
print("uploading state to huggingface.")
|
||||
huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
|
||||
|
||||
last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
|
||||
if last_n_steps is not None:
|
||||
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
|
||||
remove_step_no = step_no - last_n_steps - 1
|
||||
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
|
||||
|
||||
if remove_step_no > 0:
|
||||
state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
print(f"removing old state: {state_dir_old}")
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
print("saving last state.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
|
||||
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
|
||||
accelerator.save_state(state_dir)
|
||||
|
||||
if args.save_state_to_huggingface:
|
||||
print("uploading last state to huggingface.")
|
||||
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
|
||||
@@ -3024,7 +3134,7 @@ def save_sd_model_on_train_end(
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
if save_stable_diffusion_format:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user