From 61a61c51ee59184f24440a629afb2dcb760f0136 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 1 Jan 2023 21:46:38 +0900 Subject: [PATCH 1/2] Add --save_last_n_epochs option --- train_db.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/train_db.py b/train_db.py index 1dde882c..5a1fb721 100644 --- a/train_db.py +++ b/train_db.py @@ -27,6 +27,7 @@ import itertools import math import os import random +import shutil from tqdm import tqdm import torch @@ -1101,16 +1102,28 @@ def train(args): ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) + if args.save_last_n_epochs is not None: + old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_last_n_epochs)) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) else: out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), src_diffusers_model_path, use_safetensors=use_safetensors) + if args.save_last_n_epochs is not None: + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_last_n_epochs)) + if os.path.exists(out_dir_old): + shutil.rmtree(out_dir_old) if args.save_state: print("saving state.") accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + if args.save_last_n_epochs is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_last_n_epochs)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) is_main_process = accelerator.is_main_process if is_main_process: @@ -1173,6 +1186,8 @@ if __name__ == '__main__': help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_last_n_epochs", type=int, default=None, + help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") From 85d8b4912955228ab57f194b0eba951a338849b0 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 1 Jan 2023 23:36:20 +0900 Subject: [PATCH 2/2] Fix calculation for the old epoch --- train_db.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_db.py b/train_db.py index 5a1fb721..a0e2357f 100644 --- a/train_db.py +++ b/train_db.py @@ -1103,7 +1103,7 @@ def train(args): model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) if args.save_last_n_epochs is not None: - old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_last_n_epochs)) + old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) if os.path.exists(old_ckpt_file): os.remove(old_ckpt_file) else: @@ -1113,7 +1113,7 @@ def train(args): unwrap_model(unet), src_diffusers_model_path, use_safetensors=use_safetensors) if args.save_last_n_epochs is not None: - out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_last_n_epochs)) + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) if os.path.exists(out_dir_old): shutil.rmtree(out_dir_old) @@ -1121,7 +1121,7 @@ def train(args): print("saving state.") accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) if args.save_last_n_epochs is not None: - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_last_n_epochs)) + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) if os.path.exists(state_dir_old): shutil.rmtree(state_dir_old)