mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
add save_last_n_epochs_state to train_network
This commit is contained in:
@@ -1299,7 +1299,6 @@ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
|||||||
|
|
||||||
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
||||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||||
remove_epoch_no = None
|
|
||||||
if saving:
|
if saving:
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
save_func()
|
save_func()
|
||||||
@@ -1356,12 +1355,9 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
print("saving state.")
|
print("saving state.")
|
||||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
||||||
|
|
||||||
remove_epoch_no = None
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
last_n_epoch = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
if last_n_epochs is not None:
|
||||||
if last_n_epoch is not None:
|
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
|
||||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch
|
|
||||||
|
|
||||||
if remove_epoch_no is not None:
|
|
||||||
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||||
if os.path.exists(state_dir_old):
|
if os.path.exists(state_dir_old):
|
||||||
print(f"removing old state: {state_dir_old}")
|
print(f"removing old state: {state_dir_old}")
|
||||||
|
|||||||
@@ -367,9 +367,9 @@ def train(args):
|
|||||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||||
os.remove(old_ckpt_file)
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||||
if saving and args.save_state:
|
if saving and args.save_state:
|
||||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
|
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user