Removed --save_last_n_epochs_model

This commit is contained in:
Yuta Hayashibe
2023-01-16 21:02:27 +09:00
parent c6e28faa57
commit 3815b82bef

View File

@@ -1029,7 +1029,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_last_n_epochs_model", type=int, default=None, help="save last N checkpoints of model (overrides the value of --save_last_n_epochs) / 最大Nエポックモデルを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
@@ -1305,9 +1304,8 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
os.makedirs(args.output_dir, exist_ok=True)
save_func()
last_n_epoch = args.save_last_n_epochs_model if args.save_last_n_epochs_model else args.save_last_n_epochs
if last_n_epoch is not None:
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch
if args.save_last_n_epochs is not None:
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
remove_old_func(remove_epoch_no)
return saving