From 96f06d917e8666cc6d7f882ebcd4fbf051765812 Mon Sep 17 00:00:00 2001 From: yushan777 Date: Sun, 24 Sep 2023 12:22:36 +0100 Subject: [PATCH 1/3] Add setting to save checkpoints only after X steps. If save_every_n_steps and save_every_n_steps_after_x are both set, a checkpoint is saved every N steps, but only once global_step has reached the value given by save_every_n_steps_after_x --- library/train_util.py | 3 +++ sdxl_train.py | 4 +++- sdxl_train_control_net_lllite.py | 4 +++- sdxl_train_control_net_lllite_old.py | 4 +++- train_controlnet.py | 4 +++- train_db.py | 4 +++- train_network.py | 4 +++- train_textual_inversion.py | 4 +++- train_textual_inversion_XTI.py | 4 +++- 9 files changed, 27 insertions(+), 8 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 35bfb5f5..c30b7c42 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2751,6 +2751,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" ) + parser.add_argument( + "--save_every_n_steps_after_x", type=int, default=None, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します" + ) parser.add_argument( "--save_n_epoch_ratio", type=int, diff --git a/sdxl_train.py b/sdxl_train.py index 6b255d67..6a476a1f 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -593,7 +593,9 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/sdxl_train_control_net_lllite.py 
b/sdxl_train_control_net_lllite.py index 61ebfb58..9cf41c81 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -485,7 +485,9 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f8169bdb..98a9818c 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -455,7 +455,9 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, global_step) diff --git a/train_controlnet.py b/train_controlnet.py index 42da4412..83aae7cf 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -482,7 +482,9 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_db.py b/train_db.py index feb14778..831bc445 100644 --- a/train_db.py +++ b/train_db.py @@ -361,7 +361,9 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/train_network.py b/train_network.py index 0e2e0fa9..0f566353 100644 --- a/train_network.py +++ b/train_network.py @@ -828,7 +828,9 @@ class NetworkTrainer: self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, global_step) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 1c7b7fcb..9d79fad2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -622,7 +622,9 @@ class TextualInversionTrainer: ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs_list = [] diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2c5673be..2ab48b63 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -499,7 +499,9 @@ def train(args): # ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ + args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs = ( From a2c0f3644b5d9e1f60d9720419bd3aea510684c2 Mon Sep 17 00:00:00 2001 From: yushan777 Date: Sun, 24 Sep 2023 14:07:37 +0100 Subject: [PATCH 2/3] Update train_util.py: change the default of --save_every_n_steps_after_x from None to 0 --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index c30b7c42..0ffbaf00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2752,7 +2752,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" ) parser.add_argument( - "--save_every_n_steps_after_x", type=int, default=None, help="save checkpoint every N steps only after X 
steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します" + "--save_every_n_steps_after_x", type=int, default=0, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します" ) parser.add_argument( "--save_n_epoch_ratio", From 51e1b45abd9da75d6509fd72a7760d2de0ce1396 Mon Sep 17 00:00:00 2001 From: yushan777 Date: Sun, 24 Sep 2023 15:24:29 +0100 Subject: [PATCH 3/3] update: remove the redundant `is not None` check on save_every_n_steps_after_x from the step-save condition (the argument now defaults to 0) --- sdxl_train.py | 4 +--- sdxl_train_control_net_lllite.py | 4 +--- sdxl_train_control_net_lllite_old.py | 4 +--- train_controlnet.py | 4 +--- train_db.py | 4 +--- train_network.py | 3 +-- train_textual_inversion.py | 4 +--- train_textual_inversion_XTI.py | 4 +--- 8 files changed, 8 insertions(+), 23 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 6a476a1f..724bf558 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -593,9 +593,7 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9cf41c81..99abb76c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -485,9 +485,7 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and 
global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 98a9818c..fde6025c 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -455,9 +455,7 @@ def train(args): # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) diff --git a/train_controlnet.py b/train_controlnet.py index 83aae7cf..6444b654 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -482,9 +482,7 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, global_step) diff --git a/train_db.py b/train_db.py index 831bc445..540a2dfd 100644 --- a/train_db.py +++ b/train_db.py @@ -361,9 +361,7 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path diff --git a/train_network.py b/train_network.py index 0f566353..7c5c5e7e 100644 --- a/train_network.py +++ b/train_network.py @@ -828,8 +828,7 @@ class NetworkTrainer: self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9d79fad2..c09d0612 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -622,9 +622,7 @@ class TextualInversionTrainer: ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if 
accelerator.is_main_process: updated_embs_list = [] diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2ab48b63..1433c71d 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -499,9 +499,7 @@ def train(args): # ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and \ - args.save_every_n_steps_after_x is not None and global_step >= args.save_every_n_steps_after_x: - + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x: accelerator.wait_for_everyone() if accelerator.is_main_process: updated_embs = (