enable cache_latents when _to_disk #438

This commit is contained in:
Kohya S
2023-04-25 08:08:49 +09:00
parent 9bb52acc14
commit 1890535d1b

View File

@@ -2185,6 +2185,12 @@ def verify_training_args(args: argparse.Namespace):
if args.v2 and args.clip_skip is not None: if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
print(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
def add_dataset_arguments( def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -2963,7 +2969,7 @@ def get_remove_step_no(args: argparse.Namespace, step_no: int):
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
# save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する # save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する
remove_step_no = step_no - args.save_last_n_steps - 1 remove_step_no = step_no - args.save_last_n_steps - 1
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
if remove_step_no < 0: if remove_step_no < 0:
return None return None
@@ -3005,7 +3011,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if save_stable_diffusion_format: if save_stable_diffusion_format:
ext = ".safetensors" if use_safetensors else ".ckpt" ext = ".safetensors" if use_safetensors else ".ckpt"
if on_epoch_end: if on_epoch_end:
ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no) ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no)
else: else: