mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
added setting to save checkpoint only after X number of steps
if save_every_n_steps is set, and save_every_n_steps_after_x is set, then it will save only after the number of steps defined by save_every_n_steps_after_x
This commit is contained in:
@@ -2751,6 +2751,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_every_n_steps_after_x", type=int, default=None, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_n_epoch_ratio",
|
||||
type=int,
|
||||
|
||||
@@ -593,7 +593,9 @@ def train(args):
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
|
||||
@@ -485,7 +485,9 @@ def train(args):
|
||||
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -455,7 +455,9 @@ def train(args):
|
||||
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -482,7 +482,9 @@ def train(args):
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -361,7 +361,9 @@ def train(args):
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
|
||||
@@ -828,7 +828,9 @@ class NetworkTrainer:
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -622,7 +622,9 @@ class TextualInversionTrainer:
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs_list = []
|
||||
|
||||
@@ -499,7 +499,9 @@ def train(args):
|
||||
# )
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \
|
||||
args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = (
|
||||
|
||||
Reference in New Issue
Block a user