mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
@@ -1029,6 +1029,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||||
|
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
@@ -1306,7 +1307,7 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
|
|||||||
if args.save_last_n_epochs is not None:
|
if args.save_last_n_epochs is not None:
|
||||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
||||||
remove_old_func(remove_epoch_no)
|
remove_old_func(remove_epoch_no)
|
||||||
return saving, remove_epoch_no
|
return saving
|
||||||
|
|
||||||
|
|
||||||
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
||||||
@@ -1346,14 +1347,20 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
|||||||
save_func = save_du
|
save_func = save_du
|
||||||
remove_old_func = remove_du
|
remove_old_func = remove_du
|
||||||
|
|
||||||
saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
||||||
if saving and args.save_state:
|
if saving and args.save_state:
|
||||||
save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no)
|
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
||||||
|
|
||||||
|
|
||||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no):
|
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
||||||
print("saving state.")
|
print("saving state.")
|
||||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
||||||
|
|
||||||
|
remove_epoch_no = None
|
||||||
|
last_n_epoch = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
|
if last_n_epoch is not None:
|
||||||
|
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch
|
||||||
|
|
||||||
if remove_epoch_no is not None:
|
if remove_epoch_no is not None:
|
||||||
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||||
if os.path.exists(state_dir_old):
|
if os.path.exists(state_dir_old):
|
||||||
|
|||||||
Reference in New Issue
Block a user