diff --git a/library/train_util.py b/library/train_util.py index fd46f905..fb8c0700 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3211,6 +3211,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", ) + parser.add_argument( + "--save_every_n_steps_after_x", type=int, default=0, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します" + ) parser.add_argument( "--save_n_epoch_ratio", type=int, diff --git a/sdxl_train.py b/sdxl_train.py index b533b274..605ce208 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -775,7 +775,7 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0e67cde5..85f591dc 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -500,7 +500,7 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 4a01f9e2..486816d4 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -460,7 +460,7 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_controlnet.py b/train_controlnet.py index 6938c4bc..f9df885c 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -521,7 +521,7 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_db.py b/train_db.py index e7cf3cde..eeb0f642 100644 --- a/train_db.py +++ b/train_db.py @@ -399,7 +399,7 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/train_network.py b/train_network.py index 6953bb17..778e8dda 100644 --- a/train_network.py +++ b/train_network.py @@ -1037,7 +1037,8 @@ class NetworkTrainer: self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 37349da7..fbbeeef6 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -646,7 +646,7 @@ class TextualInversionTrainer: ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs_list = [] diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index fac0787b..c136ae96 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -515,7 +515,7 @@ def train(args): # ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs = (