mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
unify dataset and save functions
This commit is contained in:
300
fine_tune.py
300
fine_tune.py
@@ -1,27 +1,17 @@
|
||||
# training with captions
|
||||
# XXX dropped option: fine_tune
|
||||
# XXX dropped option: hypernetwork training
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
import importlib
|
||||
import time
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from transformers import CLIPTokenizer
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
|
||||
|
||||
@@ -29,211 +19,21 @@ def collate_fn(examples):
|
||||
return examples[0]
|
||||
|
||||
|
||||
class FineTuningDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, dataset_repeats, debug) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.metadata = metadata
|
||||
self.train_data_dir = train_data_dir
|
||||
self.batch_size = batch_size
|
||||
self.tokenizer: CLIPTokenizer = tokenizer
|
||||
self.max_token_length = max_token_length
|
||||
self.shuffle_caption = shuffle_caption
|
||||
self.shuffle_keep_tokens = shuffle_keep_tokens
|
||||
self.debug = debug
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
print("make buckets")
|
||||
|
||||
# 最初に数を数える
|
||||
self.bucket_resos = set()
|
||||
for img_md in metadata.values():
|
||||
if 'train_resolution' in img_md:
|
||||
self.bucket_resos.add(tuple(img_md['train_resolution']))
|
||||
self.bucket_resos = list(self.bucket_resos)
|
||||
self.bucket_resos.sort()
|
||||
print(f"number of buckets: {len(self.bucket_resos)}")
|
||||
|
||||
reso_to_index = {}
|
||||
for i, reso in enumerate(self.bucket_resos):
|
||||
reso_to_index[reso] = i
|
||||
|
||||
# bucketに割り当てていく
|
||||
self.buckets = [[] for _ in range(len(self.bucket_resos))]
|
||||
n = 1 if dataset_repeats is None else dataset_repeats
|
||||
images_count = 0
|
||||
for image_key, img_md in metadata.items():
|
||||
if 'train_resolution' not in img_md:
|
||||
continue
|
||||
if not os.path.exists(self.image_key_to_npz_file(image_key)):
|
||||
continue
|
||||
|
||||
reso = tuple(img_md['train_resolution'])
|
||||
for _ in range(n):
|
||||
self.buckets[reso_to_index[reso]].append(image_key)
|
||||
images_count += n
|
||||
|
||||
# 参照用indexを作る
|
||||
self.buckets_indices = []
|
||||
for bucket_index, bucket in enumerate(self.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append((bucket_index, batch_index))
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
self.images_count = images_count
|
||||
|
||||
def show_buckets(self):
|
||||
for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)):
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
def shuffle_buckets(self):
|
||||
random.shuffle(self.buckets_indices)
|
||||
for bucket in self.buckets:
|
||||
random.shuffle(bucket)
|
||||
|
||||
def image_key_to_npz_file(self, image_key):
|
||||
npz_file_norm = os.path.splitext(image_key)[0] + '.npz'
|
||||
if os.path.exists(npz_file_norm):
|
||||
if random.random() < .5:
|
||||
npz_file_flip = os.path.splitext(image_key)[0] + '_flip.npz'
|
||||
if os.path.exists(npz_file_flip):
|
||||
return npz_file_flip
|
||||
return npz_file_norm
|
||||
|
||||
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
||||
if random.random() < .5:
|
||||
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
||||
if os.path.exists(npz_file_flip):
|
||||
return npz_file_flip
|
||||
return npz_file_norm
|
||||
|
||||
def load_latent(self, image_key):
|
||||
return np.load(self.image_key_to_npz_file(image_key))['arr_0']
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index == 0:
|
||||
self.shuffle_buckets()
|
||||
|
||||
bucket = self.buckets[self.buckets_indices[index][0]]
|
||||
image_index = self.buckets_indices[index][1] * self.batch_size
|
||||
|
||||
input_ids_list = []
|
||||
latents_list = []
|
||||
captions = []
|
||||
for image_key in bucket[image_index:image_index + self.batch_size]:
|
||||
img_md = self.metadata[image_key]
|
||||
caption = img_md.get('caption')
|
||||
tags = img_md.get('tags')
|
||||
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ', ' + tags
|
||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}"
|
||||
|
||||
latents = self.load_latent(image_key)
|
||||
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
if self.shuffle_keep_tokens is None:
|
||||
random.shuffle(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
random.shuffle(tokens)
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ",".join(tokens).strip()
|
||||
|
||||
captions.append(caption)
|
||||
|
||||
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
||||
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
||||
|
||||
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
iids_list = []
|
||||
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||
# v1
|
||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
||||
ids_chunk = (input_ids[0].unsqueeze(0),
|
||||
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0))
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
iids_list.append(ids_chunk)
|
||||
else:
|
||||
# v2
|
||||
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
||||
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
||||
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
|
||||
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
||||
ids_chunk[-1] = self.tokenizer.eos_token_id
|
||||
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
||||
ids_chunk[1] = self.tokenizer.eos_token_id
|
||||
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
latents_list.append(torch.FloatTensor(latents))
|
||||
|
||||
example = {}
|
||||
example['input_ids'] = torch.stack(input_ids_list)
|
||||
example['latents'] = torch.stack(latents_list)
|
||||
if self.debug:
|
||||
example['image_keys'] = bucket[image_index:image_index + self.batch_size]
|
||||
example['captions'] = captions
|
||||
return example
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
# verify load/save model formats
|
||||
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
|
||||
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
|
||||
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
@@ -253,6 +53,21 @@ def train(args):
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
# Diffusers版のxformers使用フラグを設定する関数
|
||||
def set_diffusers_xformers_flag(model, valid):
|
||||
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
||||
@@ -308,7 +123,11 @@ def train(args):
|
||||
else:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False) # text encoderは学習しない
|
||||
text_encoder.eval()
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
text_encoder.train() # required for gradient_checkpointing
|
||||
else:
|
||||
text_encoder.eval()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
@@ -365,12 +184,7 @@ def train(args):
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
||||
|
||||
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
@@ -413,7 +227,6 @@ def train(args):
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# with torch.no_grad():
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
@@ -435,7 +248,6 @@ def train(args):
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
# Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
@@ -478,63 +290,26 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
def save_func(file):
|
||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet),
|
||||
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
|
||||
train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func)
|
||||
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
||||
print("saving checkpoint.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
|
||||
|
||||
if save_stable_diffusion_format:
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
||||
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet),
|
||||
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
|
||||
if args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1)))
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
if fine_tuning:
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
else:
|
||||
hypernetwork = unwrap_model(hypernetwork)
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
print("saving last state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME))
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors))
|
||||
|
||||
if fine_tuning:
|
||||
if save_stable_diffusion_format:
|
||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
||||
src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae)
|
||||
else:
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||
out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
||||
src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors)
|
||||
else:
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
save_hypernetwork(ckpt_file, hypernetwork)
|
||||
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, global_step, text_encoder, unet, vae)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
@@ -544,9 +319,8 @@ if __name__ == '__main__':
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
|
||||
parser.add_argument("--use_safetensors", action='store_true',
|
||||
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user