From 095b8035e63f7c79a232114d8f0e1ec27f431ebc Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 23:33:38 +0800 Subject: [PATCH 1/2] save state on train end --- fine_tune.py | 2 +- library/train_util.py | 5 +++++ sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 9 files changed, 13 insertions(+), 8 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 875a9195..46f12828 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -457,7 +457,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/library/train_util.py b/library/train_util.py index d2b69edb..b3ca15f5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2890,6 +2890,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する", + ) parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") diff --git a/sdxl_train.py b/sdxl_train.py index e0df263d..107bb945 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -712,7 +712,7 @@ def train(args): accelerator.end_training() - if args.save_state: # and is_main_process: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 1e5f9234..e99b4e35 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -549,7 +549,7 @@ def train(args): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: diff --git a/train_controlnet.py b/train_controlnet.py index dc73a91c..e44f0885 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -565,7 +565,7 @@ def train(args): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく diff --git a/train_db.py b/train_db.py index 8d36097a..41a9a7b9 100644 --- a/train_db.py +++ b/train_db.py @@ -444,7 +444,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/train_network.py b/train_network.py index e0fa6945..4707d5ae 100644 --- a/train_network.py +++ b/train_network.py @@ -935,7 +935,7 @@ class NetworkTrainer: accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) if is_main_process: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index df1d8485..0266bc14 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -732,7 +732,7 @@ class TextualInversionTrainer: accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 695fad2a..ad7c267e 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -586,7 +586,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() From d282c450026dcfd5f1fd5856f5087ebaed47be46 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:56:09 +0800 Subject: [PATCH 2/2] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4707d5ae..3db583f1 100644 --- a/train_network.py +++ b/train_network.py @@ -935,7 +935,7 @@ class NetworkTrainer: accelerator.end_training() - if is_main_process and args.save_state or args.save_state_on_train_end: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: