From 5e817e4343cdf839096b430877609c6f52749a30 Mon Sep 17 00:00:00 2001 From: forestsource Date: Sun, 22 Jan 2023 02:57:12 +0900 Subject: [PATCH] Add save_n_epoch_ratio --- fine_tune.py | 2 ++ library/train_util.py | 2 ++ train_db.py | 2 ++ train_network.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/fine_tune.py b/fine_tune.py index 02f665bd..8e615203 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -200,6 +200,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/library/train_util.py b/library/train_util.py index aa65dc3c..5ff0280e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1028,6 +1028,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_n_epoch_ratio", type=int, default=None, + help="save checkpoint N epoch ratio / 学習中のモデルを指定のエポック割合で保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_state", action="store_true", diff --git a/train_db.py b/train_db.py index 8ac503ea..fe6fd4e6 100644 --- a/train_db.py +++ b/train_db.py @@ -176,6 +176,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/train_network.py b/train_network.py index b2c7b579..d3282da9 100644 --- a/train_network.py +++ b/train_network.py @@ -192,6 +192,8 @@ def train(args): # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps