mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
unify dataset and save functions
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
# common functions for training
|
||||
# TODO test no_token_padding option
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from typing import NamedTuple
|
||||
from accelerate import Accelerator
|
||||
@@ -31,18 +33,16 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here; v2 and v2.1 share the same tokenizer spec
|
||||
|
||||
# checkpoint file names
|
||||
EPOCH_STATE_NAME = "epoch-{:06d}-state"
|
||||
LAST_STATE_NAME = "last-state"
|
||||
|
||||
EPOCH_FILE_NAME = "epoch-{:06d}"
|
||||
LAST_FILE_NAME = "last"
|
||||
|
||||
LAST_DIFFUSERS_DIR_NAME = "last"
|
||||
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
|
||||
|
||||
EPOCH_STATE_NAME = "{}-{:06d}-state"
|
||||
EPOCH_FILE_NAME = "{}-{:06d}"
|
||||
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
|
||||
LAST_STATE_NAME = "{}-state"
|
||||
DEFAULT_EPOCH_NAME = "epoch"
|
||||
DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
|
||||
# region dataset
|
||||
|
||||
|
||||
class ImageInfo():
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
self.image_key: str = image_key
|
||||
@@ -76,6 +76,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.flip_aug = flip_aug
|
||||
self.color_aug = color_aug
|
||||
self.debug_dataset = debug_dataset
|
||||
self.padding_disabled = False
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
@@ -101,6 +102,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_data: dict[str, ImageInfo] = {}
|
||||
|
||||
def disable_padding(self):
    """Turn on the no-padding mode by setting ``self.padding_disabled`` to True."""
    self.padding_disabled = True
|
||||
|
||||
def process_caption(self, caption):
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
@@ -408,11 +412,18 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
caption = self.process_caption(image_info.caption)
|
||||
captions.append(caption)
|
||||
input_ids_list.append(self.get_input_ids(caption))
|
||||
if not self.padding_disabled: # this option might be omitted in future
|
||||
input_ids_list.append(self.get_input_ids(caption))
|
||||
|
||||
example = {}
|
||||
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
||||
example['input_ids'] = torch.stack(input_ids_list)
|
||||
|
||||
if self.padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
# batch processing seems to be good
|
||||
example['input_ids'] = torch.stack(input_ids_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
@@ -664,6 +675,7 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
def debug_dataset(train_dataset):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
@@ -973,12 +985,13 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||
parser.add_argument("--output_dir", type=str, default=None,
|
||||
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||
parser.add_argument("--output_name", type=str, default=None,
|
||||
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||
parser.add_argument("--save_state", action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||
@@ -1034,6 +1047,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
parser.add_argument("--shuffle_caption", action="store_true",
|
||||
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
||||
parser.add_argument("--keep_tokens", type=int, default=None,
|
||||
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
||||
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
||||
@@ -1064,7 +1079,19 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
|
||||
|
||||
|
||||
def prepare_dataset_args(args: argparse.Namespace, support_caption: bool):
|
||||
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
    """Register the Stable Diffusion model-saving options on ``parser``.

    Two options are added:

    * ``--save_model_as``: output format, one of ``ckpt``, ``safetensors``,
      ``diffusers`` or ``diffusers_safetensors``; ``None`` when omitted.
    * ``--use_safetensors``: boolean flag, False unless given.
    """
    save_formats = [None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"]
    parser.add_argument(
        "--save_model_as",
        type=str,
        default=None,
        choices=save_formats,
        help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)",
    )
    parser.add_argument(
        "--use_safetensors",
        action="store_true",
        help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
    )
|
||||
|
||||
|
||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
# backward compatibility
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
args.caption_extention = None
|
||||
|
||||
if args.cache_latents:
|
||||
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
||||
|
||||
@@ -1083,7 +1110,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_caption: bool):
|
||||
else:
|
||||
args.face_crop_aug_range = None
|
||||
|
||||
if support_caption:
|
||||
if support_metadata:
|
||||
if args.in_json is not None and args.color_aug:
|
||||
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
|
||||
|
||||
@@ -1216,29 +1243,95 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
def save_on_epoch_end(args: argparse.Namespace, accelerator, epoch: int, num_train_epochs: int, save_func):
|
||||
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
||||
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
    """Return ``(model_name, ckpt_name)`` for the checkpoint of one epoch.

    ``model_name`` is ``args.output_name``, falling back to ``DEFAULT_EPOCH_NAME``
    when no output name was given. ``ckpt_name`` is ``EPOCH_FILE_NAME`` formatted
    with that name and ``epoch``, followed by ``.safetensors`` when
    ``use_safetensors`` is truthy and ``.ckpt`` otherwise.
    """
    if args.output_name is None:
        model_name = DEFAULT_EPOCH_NAME
    else:
        model_name = args.output_name

    extension = ".safetensors" if use_safetensors else ".ckpt"
    return model_name, EPOCH_FILE_NAME.format(model_name, epoch) + extension
|
||||
|
||||
|
||||
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
    """Save an intermediate checkpoint at the end of an epoch when one is due.

    A save is due when ``epoch_no`` is a multiple of ``args.save_every_n_epochs``
    and ``epoch_no`` is still below ``num_train_epochs``.

    Args:
        args: parsed options; reads ``save_every_n_epochs``, ``save_last_n_epochs``
            and ``output_dir``. ``save_every_n_epochs`` is used as a modulus, so
            it must be a non-zero integer here.
        save_func: zero-argument callable that writes the checkpoint.
        remove_old_func: callable taking the epoch number whose checkpoint should
            be removed.
        epoch_no: number of the epoch that just finished.
        num_train_epochs: total number of epochs of the run.

    Returns:
        ``(saving, remove_epoch_no)``. ``saving`` tells whether a checkpoint was
        written. ``remove_epoch_no`` is the epoch number passed to
        ``remove_old_func``, or ``None`` when nothing was removed.
    """
    saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
    remove_epoch_no = None
    if saving:
        print("saving checkpoint.")
        os.makedirs(args.output_dir, exist_ok=True)
        # The callback owns the file name and format; nothing is passed to it.
        save_func()

        if args.save_last_n_epochs is not None:
            # Step back save_last_n_epochs save intervals to find the checkpoint
            # that falls out of the retention window.
            remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
            remove_old_func(remove_epoch_no)
    return saving, remove_epoch_no
|
||||
|
||||
|
||||
def save_last_state(args, accelerator):
|
||||
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
    """Save the model, and optionally the training state, at the end of an epoch.

    The save and remove callbacks are chosen by ``save_stable_diffusion_format``
    (single checkpoint file versus Diffusers directory) and handed to
    ``save_on_epoch_end``, which decides whether this epoch is saved. When it is
    and ``args.save_state`` is set, the accelerator state is saved too, and the
    state directory of the epoch reported as removed is deleted if it exists.

    ``epoch`` is zero-based; file and directory names use ``epoch + 1``.
    """
    epoch_no = epoch + 1
    model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)

    def diffusers_dir(number):
        # Diffusers-format output directory for the given epoch number.
        return os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, number))

    def state_dir(number):
        # Accelerator state directory for the given epoch number.
        return os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, number))

    def write_checkpoint():
        target = os.path.join(args.output_dir, ckpt_name)
        model_util.save_stable_diffusion_checkpoint(
            args.v2, target, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae)

    def drop_checkpoint(old_epoch_no):
        _, stale_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
        stale_file = os.path.join(args.output_dir, stale_name)
        if os.path.exists(stale_file):
            os.remove(stale_file)

    def write_diffusers():
        target_dir = diffusers_dir(epoch_no)
        os.makedirs(target_dir, exist_ok=True)
        model_util.save_diffusers_checkpoint(
            args.v2, target_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors)

    def drop_diffusers(old_epoch_no):
        stale_dir = diffusers_dir(old_epoch_no)
        if os.path.exists(stale_dir):
            shutil.rmtree(stale_dir)

    if save_stable_diffusion_format:
        save_func, remove_old_func = write_checkpoint, drop_checkpoint
    else:
        save_func, remove_old_func = write_diffusers, drop_diffusers

    saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
    if not (saving and args.save_state):
        return

    print("saving state.")
    accelerator.save_state(state_dir(epoch_no))

    # NOTE(review): the old-state cleanup is assumed to sit inside the
    # save_state branch; confirm against the original indentation.
    if remove_epoch_no is None:
        return
    stale_state = state_dir(remove_epoch_no)
    if os.path.exists(stale_state):
        shutil.rmtree(stale_state)
|
||||
|
||||
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    """Save the final training state once training has finished.

    The state goes to ``args.output_dir`` under ``LAST_STATE_NAME`` formatted
    with ``args.output_name``, or with ``DEFAULT_LAST_OUTPUT_NAME`` when no
    output name was given. The output directory is created if missing.

    Args:
        args: parsed options; reads ``output_dir`` and ``output_name``.
        accelerator: object whose ``save_state(path)`` writes the state.
    """
    print("saving last state.")
    os.makedirs(args.output_dir, exist_ok=True)
    model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
    # LAST_STATE_NAME is a format template, so it is always formatted before use;
    # saving to the raw template would create a directory named "{}-state".
    accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||
|
||||
|
||||
def save_last_model(args, save_func):
    """Save the final trained model into ``args.output_dir`` through ``save_func``.

    The file name is ``LAST_FILE_NAME``, a dot, then ``args.save_model_as``.
    ``save_func`` is called with the full path, and progress is printed before
    and after the call.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    file_name = ".".join((LAST_FILE_NAME, args.save_model_as))
    ckpt_file = os.path.join(args.output_dir, file_name)
    print(f"save trained model to {ckpt_file}")
    save_func(ckpt_file)
    print("model saved.")
|
||||
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
    """Save the final trained model once training has finished.

    The output is named after ``args.output_name``, or
    ``DEFAULT_LAST_OUTPUT_NAME`` when no output name was given.

    * With ``save_stable_diffusion_format`` set, one checkpoint file is written
      to ``args.output_dir`` with a ``.safetensors`` or ``.ckpt`` extension,
      chosen by ``use_safetensors``.
    * Otherwise a Diffusers-format model is written into the subdirectory
      ``args.output_dir/<model name>``.

    Missing directories are created in both cases.
    """
    model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name

    if save_stable_diffusion_format:
        os.makedirs(args.output_dir, exist_ok=True)

        ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
        model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
                                                    src_path, epoch, global_step, save_dtype, vae)
    else:
        out_dir = os.path.join(args.output_dir, model_name)
        os.makedirs(out_dir, exist_ok=True)

        # Report the directory actually written to, not just its parent.
        print(f"save trained model as Diffusers to {out_dir}")
        model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
                                             src_path, vae=vae, use_safetensors=use_safetensors)
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user