added setting to save checkpoint only after X number of steps

if save_every_n_steps is set, and save_every_n_steps_after_x is set, then it will save only after the number of steps defined by save_every_n_steps_after_x
This commit is contained in:
yushan777
2023-09-24 12:22:36 +01:00
parent 1e395ed285
commit 96f06d917e
9 changed files with 27 additions and 8 deletions

View File

@@ -2751,6 +2751,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
)
parser.add_argument(
"--save_every_n_steps_after_x", type=int, default=None, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します"
)
parser.add_argument(
"--save_n_epoch_ratio",
type=int,

View File

@@ -593,7 +593,9 @@ def train(args):
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path

View File

@@ -485,7 +485,9 @@ def train(args):
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)

View File

@@ -455,7 +455,9 @@ def train(args):
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)

View File

@@ -482,7 +482,9 @@ def train(args):
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)

View File

@@ -361,7 +361,9 @@ def train(args):
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path

View File

@@ -828,7 +828,9 @@ class NetworkTrainer:
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)

View File

@@ -622,7 +622,9 @@ class TextualInversionTrainer:
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
updated_embs_list = []

View File

@@ -499,7 +499,9 @@ def train(args):
# )
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
updated_embs = (