diff --git a/train_db.py b/train_db.py index 35ae1212..50ed5a64 100644 --- a/train_db.py +++ b/train_db.py @@ -9,6 +9,7 @@ import itertools import math import os import random +import shutil from tqdm import tqdm import torch @@ -816,16 +817,33 @@ def train(args): ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) + if args.save_last_n_epochs is not None: + old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) else: out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), src_diffusers_model_path, use_safetensors=use_safetensors) + if args.save_last_n_epochs is not None: + out_dir_old = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) + if os.path.exists(out_dir_old): + shutil.rmtree(out_dir_old) + + + + + if args.save_state: print("saving state.") accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) + if args.save_last_n_epochs is not None: + state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) is_main_process = accelerator.is_main_process if is_main_process: @@ -888,6 +906,8 @@ if __name__ == '__main__': help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_last_n_epochs", type=int, default=None, + help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -942,3 +962,4 @@ if __name__ == '__main__': args = parser.parse_args() train(args) +