unify dataset and save functions

This commit is contained in:
Kohya S
2023-01-05 08:10:22 +09:00
parent 4c35006731
commit f56988b252
5 changed files with 287 additions and 1016 deletions

View File

@@ -1,7 +1,9 @@
# common functions for training
# TODO test no_token_padding option
import argparse
import json
import shutil
import time
from typing import NamedTuple
from accelerate import Accelerator
@@ -31,18 +33,16 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
# checkpointファイル名
EPOCH_STATE_NAME = "epoch-{:06d}-state"
LAST_STATE_NAME = "last-state"
EPOCH_FILE_NAME = "epoch-{:06d}"
LAST_FILE_NAME = "last"
LAST_DIFFUSERS_DIR_NAME = "last"
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
EPOCH_STATE_NAME = "{}-{:06d}-state"
EPOCH_FILE_NAME = "{}-{:06d}"
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
LAST_STATE_NAME = "{}-state"
DEFAULT_EPOCH_NAME = "epoch"
DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset
class ImageInfo():
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
@@ -76,6 +76,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.flip_aug = flip_aug
self.color_aug = color_aug
self.debug_dataset = debug_dataset
self.padding_disabled = False
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -101,6 +102,9 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_data: dict[str, ImageInfo] = {}
def disable_padding(self):
self.padding_disabled = True
def process_caption(self, caption):
if self.shuffle_caption:
tokens = caption.strip().split(",")
@@ -408,11 +412,18 @@ class BaseDataset(torch.utils.data.Dataset):
caption = self.process_caption(image_info.caption)
captions.append(caption)
input_ids_list.append(self.get_input_ids(caption))
if not self.padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
example = {}
example['loss_weights'] = torch.FloatTensor(loss_weights)
example['input_ids'] = torch.stack(input_ids_list)
if self.padding_disabled:
# padding=True means pad in the batch
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
# batch processing seems to be good
example['input_ids'] = torch.stack(input_ids_list)
if images[0] is not None:
images = torch.stack(images)
@@ -664,6 +675,7 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
def debug_dataset(train_dataset):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
@@ -973,12 +985,13 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None,
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--output_name", type=str, default=None,
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -1034,6 +1047,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
parser.add_argument("--shuffle_caption", action="store_true",
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子スペルミスを残してあります")
parser.add_argument("--keep_tokens", type=int, default=None,
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
@@ -1064,7 +1079,19 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
def prepare_dataset_args(args: argparse.Namespace, support_caption: bool):
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時")
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
# backward compatibility
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
args.caption_extention = None
if args.cache_latents:
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
@@ -1083,7 +1110,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_caption: bool):
else:
args.face_crop_aug_range = None
if support_caption:
if support_metadata:
if args.in_json is not None and args.color_aug:
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
@@ -1216,29 +1243,95 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
return encoder_hidden_states
def save_on_epoch_end(args: argparse.Namespace, accelerator, epoch: int, num_train_epochs: int, save_func):
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
return model_name, ckpt_name
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
remove_epoch_no = None
if saving:
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as)
save_func(ckpt_file)
save_func()
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
if args.save_last_n_epochs is not None:
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
remove_old_func(remove_epoch_no)
return saving, remove_epoch_no
def save_last_state(args, accelerator):
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
epoch_no = epoch + 1
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
if save_stable_diffusion_format:
def save_sd():
ckpt_file = os.path.join(args.output_dir, ckpt_name)
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
src_path, epoch_no, global_step, save_dtype, vae)
def remove_sd(old_epoch_no):
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
os.remove(old_ckpt_file)
save_func = save_sd
remove_old_func = remove_sd
else:
def save_du():
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
src_path, vae=vae, use_safetensors=use_safetensors)
def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
if os.path.exists(out_dir_old):
shutil.rmtree(out_dir_old)
save_func = save_du
remove_old_func = remove_du
saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
if saving and args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
if remove_epoch_no is not None:
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
if os.path.exists(state_dir_old):
shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
def save_last_model(args, save_func):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as)
print(f"save trained model to {ckpt_file}")
save_func(ckpt_file)
print("model saved.")
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
if save_stable_diffusion_format:
os.makedirs(args.output_dir, exist_ok=True)
ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
src_path, epoch, global_step, save_dtype, vae)
else:
print(f"save trained model as Diffusers to {args.output_dir}")
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
src_path, vae=vae, use_safetensors=use_safetensors)
# endregion