From 1890535d1b9af1fe9f525a5f34d3d652e482e3b6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 25 Apr 2023 08:08:49 +0900 Subject: [PATCH] enable `cache_latents` when `_to_disk` #438 --- library/train_util.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ec17e11c..8c6e3437 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2185,6 +2185,12 @@ def verify_training_args(args: argparse.Namespace): if args.v2 and args.clip_skip is not None: print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + if args.cache_latents_to_disk and not args.cache_latents: + args.cache_latents = True + print( + "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" + ) + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool @@ -2963,7 +2969,7 @@ def get_remove_step_no(args: argparse.Namespace, step_no: int): # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する # save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する - remove_step_no = step_no - args.save_last_n_steps - 1 + remove_step_no = step_no - args.save_last_n_steps - 1 remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) if remove_step_no < 0: return None @@ -3005,7 +3011,7 @@ def save_sd_model_on_epoch_end_or_stepwise( os.makedirs(args.output_dir, exist_ok=True) if save_stable_diffusion_format: ext = ".safetensors" if use_safetensors else ".ckpt" - + if on_epoch_end: ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no) else: