From df9cb2f11c4706558fe2aaa329cf70758e1e03e6 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 15 Jan 2023 17:52:22 +0900 Subject: [PATCH 1/6] Add --save_last_n_epochs_model and --save_last_n_epochs_state --- library/train_util.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 57ebf1b0..bd59d831 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1029,6 +1029,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") + parser.add_argument("--save_last_n_epochs_model", type=int, default=None, help="save last N checkpoints / 最大Nエポックモデル保存する") + parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints / 最大Nエポックstate保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1303,10 +1305,11 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc os.makedirs(args.output_dir, exist_ok=True) save_func() - if args.save_last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + last_n_epoch = args.save_last_n_epochs_model if args.save_last_n_epochs_model else args.save_last_n_epochs + if last_n_epoch is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch remove_old_func(remove_epoch_no) - return saving, remove_epoch_no + return saving def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): @@ -1346,14 +1349,20 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: save_func = save_du remove_old_func = remove_du - saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) if saving and args.save_state: - save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no) + save_state_on_epoch_end(args, accelerator, model_name, epoch_no) -def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no): +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): print("saving state.") accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + + remove_epoch_no = None + last_n_epoch = args.save_last_n_epochs_model if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epoch is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch + if remove_epoch_no is not None: state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): From d30ea7966d5aaca379e0a00f3178fdcc62d8fe2c Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 15 Jan 2023 17:56:49 +0900 Subject: [PATCH 2/6] Updated help --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bd59d831..6170782b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1029,8 +1029,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_model", type=int, default=None, help="save last N checkpoints / 最大Nエポックモデル保存する") - parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints / 最大Nエポックstate保存する") + parser.add_argument("--save_last_n_epochs_model", type=int, default=None, help="save last N checkpoints of model (overrides the value of --save_last_n_epochs) / 最大Nエポックモデルを保存する(--save_last_n_epochsの指定を上書きします)") + parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") From a8882238698f7b69e693ded173b31411f48c8034 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 15 Jan 2023 18:02:17 +0900 Subject: [PATCH 3/6] Fix a bug --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 6170782b..b9c4199b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1359,7 +1359,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) remove_epoch_no = None - last_n_epoch = args.save_last_n_epochs_model if args.save_last_n_epochs_state else args.save_last_n_epochs + last_n_epoch = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs if last_n_epoch is not None: remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch From c6e28faa576701c6cd04e2abc2a356008f133997 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 15 Jan 2023 19:43:37 +0900 Subject: [PATCH 4/6] Save state when args.save_last_n_epochs_state is designated --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index b9c4199b..63444f00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1350,7 +1350,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: remove_old_func = remove_du saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) - if saving and args.save_state: + if saving and args.save_state or args.save_last_n_epochs_state is not None: save_state_on_epoch_end(args, accelerator, model_name, epoch_no) From 3815b82bef06ff3015b11542bd672ad768378dc1 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Mon, 16 Jan 2023 21:02:27 +0900 Subject: [PATCH 5/6] Removed --save_last_n_epochs_model --- library/train_util.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 63444f00..aee762d5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1029,7 +1029,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_model", type=int, default=None, help="save last N checkpoints of model (overrides the value of --save_last_n_epochs) / 最大Nエポックモデルを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") @@ -1305,9 +1304,8 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc os.makedirs(args.output_dir, exist_ok=True) save_func() - last_n_epoch = args.save_last_n_epochs_model if args.save_last_n_epochs_model else args.save_last_n_epochs - if last_n_epoch is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs remove_old_func(remove_epoch_no) return saving From 3eb8fb187501352b54c4735b41c67128d3517ae3 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Wed, 18 Jan 2023 01:31:38 +0900 Subject: [PATCH 6/6] Make not to save state when args.save_state is False --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index aee762d5..3a7c2c8a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1348,7 +1348,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: remove_old_func = remove_du saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) - if saving and args.save_state or args.save_last_n_epochs_state is not None: + if saving and args.save_state: save_state_on_epoch_end(args, accelerator, model_name, epoch_no)