From 96f06d917e8666cc6d7f882ebcd4fbf051765812 Mon Sep 17 00:00:00 2001 From: yushan777 Date: Sun, 24 Sep 2023 12:22:36 +0100 Subject: [PATCH] added setting to save checkpoint only after X number of steps if save_every_n_steps is set, and save_every_n_steps_after_x is set, then it will save only after the number of steps defined by save_every_n_steps_after_x --- library/train_util.py | 3 +++ sdxl_train.py | 4 +++- sdxl_train_control_net_lllite.py | 4 +++- sdxl_train_control_net_lllite_old.py | 4 +++- train_controlnet.py | 4 +++- train_db.py | 4 +++- train_network.py | 4 +++- train_textual_inversion.py | 4 +++- train_textual_inversion_XTI.py | 4 +++- 9 files changed, 27 insertions(+), 8 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 35bfb5f5..c30b7c42 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2751,6 +2751,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" ) + parser.add_argument( + "--save_every_n_steps_after_x", type=int, default=None, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します" + ) parser.add_argument( "--save_n_epoch_ratio", type=int, diff --git a/sdxl_train.py b/sdxl_train.py index 6b255d67..6a476a1f 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -593,7 +593,9 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 61ebfb58..9cf41c81 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -485,7 +485,9 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f8169bdb..98a9818c 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -455,7 +455,9 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_controlnet.py b/train_controlnet.py index 42da4412..83aae7cf 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -482,7 +482,9 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_db.py b/train_db.py index feb14778..831bc445 100644 --- a/train_db.py +++ b/train_db.py @@ -361,7 +361,9 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/train_network.py b/train_network.py index 0e2e0fa9..0f566353 100644 --- a/train_network.py +++ b/train_network.py @@ -828,7 +828,9 @@ class NetworkTrainer: self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 1c7b7fcb..9d79fad2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -622,7 +622,9 @@ class TextualInversionTrainer: ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs_list = [] diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2c5673be..2ab48b63 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -499,7 +499,9 @@ def train(args): # ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs = (