mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge branch 'refactoring_training' of https://github.com/kohya-ss/sd-scripts into refactoring_training
This commit is contained in:
21
train_db.py
21
train_db.py
@@ -9,6 +9,7 @@ import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -816,16 +817,33 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
||||
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
|
||||
if args.save_last_n_epochs is not None:
|
||||
old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs))
|
||||
if os.path.exists(old_ckpt_file):
|
||||
os.remove(old_ckpt_file)
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
|
||||
unwrap_model(unet), src_diffusers_model_path,
|
||||
use_safetensors=use_safetensors)
|
||||
if args.save_last_n_epochs is not None:
|
||||
out_dir_old = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs))
|
||||
if os.path.exists(out_dir_old):
|
||||
shutil.rmtree(out_dir_old)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1)))
|
||||
if args.save_last_n_epochs is not None:
|
||||
state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs))
|
||||
if os.path.exists(state_dir_old):
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
@@ -888,6 +906,8 @@ if __name__ == '__main__':
|
||||
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||
parser.add_argument("--save_last_n_epochs", type=int, default=None,
|
||||
help="save last N checkpoints / 最大Nエポック保存する")
|
||||
parser.add_argument("--save_state", action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||
@@ -942,3 +962,4 @@ if __name__ == '__main__':
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user