add sdxl fine-tuning and LoRA

This commit is contained in:
Kohya S
2023-06-26 08:07:24 +09:00
parent 9e9df2b501
commit 747af145ed
11 changed files with 2442 additions and 754 deletions

View File

@@ -1061,6 +1061,16 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
return text_model, vae, unet
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
# only for reference
version_str = "sd"
if v2:
version_str += "_v2"
else:
version_str += "_v1"
if v_parameterization:
version_str += "_v"
return version_str
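
A quick sketch of the strings the helper above yields, derived from the function body rather than any documented contract (assumes the repo root is importable):

from library import model_util

assert model_util.get_model_version_str_for_sd1_sd2(False, False) == "sd_v1"
assert model_util.get_model_version_str_for_sd1_sd2(True, False) == "sd_v2"
assert model_util.get_model_version_str_for_sd1_sd2(True, True) == "sd_v2_v"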
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
def convert_key(key):

View File

@@ -6,6 +6,10 @@ from library import model_util
from library import sdxl_original_unet
VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
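
On VAE_SCALE_FACTOR above: SDXL's VAE scales latents by 0.13025 rather than the 0.18215 used by SD1/SD2. A minimal sketch of how such a constant is applied when round-tripping through the VAE (assumes diffusers and the public stabilityai/sdxl-vae weights; not this repository's exact code):

import torch
from diffusers import AutoencoderKL

VAE_SCALE_FACTOR = 0.13025  # SDXL; SD1/SD2 use 0.18215

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
image = torch.randn(1, 3, 1024, 1024)  # stand-in for a real image tensor
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * VAE_SCALE_FACTOR
    decoded = vae.decode(latents / VAE_SCALE_FACTOR).sample  # back to pixel space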
@@ -76,8 +80,8 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
return new_sd, text_projection, logit_scale
def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
# model_type is reserved to future use
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
# model_version is reserved for future use
# Load the state dict
if model_util.is_safetensors(ckpt_path):
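
The hunk above is truncated at the format check; a generic sketch of the usual branching that follows (an assumption about intent, not this repository's exact code):

import torch
from safetensors.torch import load_file

def load_state_dict(ckpt_path, map_location="cpu"):
    # sketch: safetensors files load directly to a flat state dict,
    # while .ckpt files usually nest the weights under "state_dict"
    if ckpt_path.endswith(".safetensors"):
        return load_file(ckpt_path, device=map_location)
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    return checkpoint.get("state_dict", checkpoint)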

View File

@@ -1069,8 +1069,8 @@ class SdxlUNet2DConditionModel(nn.Module):
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert y.shape[0] == x.shape[0]
assert x.dtype == y.dtype
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == self.dtype
emb = emb + self.label_emb(y)
@@ -1105,6 +1105,8 @@ class SdxlUNet2DConditionModel(nn.Module):
if __name__ == "__main__":
import time
print("create unet")
unet = SdxlUNet2DConditionModel()
@@ -1132,8 +1134,11 @@ if __name__ == "__main__":
print("start training")
steps = 10
batch_size = 1
for step in range(steps):
print(f"step {step}")
if step == 1:
time_start = time.perf_counter()
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
@@ -1149,3 +1154,6 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")

library/sdxl_train_util.py (new file, 384 lines)
View File

@@ -0,0 +1,384 @@
import argparse
import gc
import math
import os
from types import SimpleNamespace
from typing import Any
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import open_clip
from library import model_util, sdxl_model_util, train_util
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
DEFAULT_NOISE_OFFSET = 0.0357
# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
text_projection,
logit_scale,
ckpt_info,
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
# work on low-ram device
if args.lowram:
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
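
The per-process loop above is a staggered-loading pattern: each rank loads the checkpoint in turn, with a barrier in between, so only one process holds the full checkpoint in host RAM at a time. A stripped-down sketch of the same idea, where load_fn is a hypothetical callable standing in for _load_target_model:

from accelerate import Accelerator

def load_sequentially(accelerator: Accelerator, load_fn):
    result = None
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            result = load_fn()  # only this rank is loading right now
        accelerator.wait_for_everyone()  # barrier: next rank's turn
    return result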
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
# only supports StableDiffusion
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
assert (
load_stable_diffusion_format
), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}"
print(f"load StableDiffusion checkpoint: {name_or_path}")
(
text_encoder1,
text_encoder2,
vae,
unet,
text_projection,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
# load the VAE
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
print("additional VAE loaded")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
class WrapperTokenizer:
# wrap the open_clip tokenizer so it can be used with the same interface as HuggingFace's tokenizer
def __init__(self):
open_clip_tokenizer = open_clip.tokenizer._tokenizer
self.model_max_length = 77
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
self.pad_token_id = 0 # inferred from observed tokenizer output
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.tokenize(*args, **kwds)
def tokenize(self, text, padding, truncation, max_length, return_tensors):
assert padding == "max_length"
assert truncation == True
assert return_tensors == "pt"
input_ids = open_clip.tokenize(text, context_length=max_length)
return SimpleNamespace(**{"input_ids": input_ids})
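
A usage sketch for the wrapper (assumes the class above is in scope); it accepts the same keyword arguments that CLIPTokenizer receives elsewhere in this file:

tokenizer2 = WrapperTokenizer()
out = tokenizer2(
    "a photo of a cat",
    padding="max_length",
    truncation=True,
    max_length=77,
    return_tensors="pt",
)
print(out.input_ids.shape)  # torch.Size([1, 77]), matching the HF tokenizer's output shape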
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
original_path = TOKENIZER_PATH
tokenizer1: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer1 is None:
tokenizer1 = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer1.save_pretrained(local_tokenizer_path)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
# tokenizer2 is from open_clip
# TODO caching
tokenizer2 = WrapperTokenizer()
return [tokenizer1, tokenizer2]
def get_hidden_states(
args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
):
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
# text_encoder1
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
hidden_states1 = enc_out["hidden_states"][11]
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
pool2 = enc_out["pooler_output"]
if args.max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: merge the triple <BOS>...<EOS> chunks back into a single <BOS>...<EOS>
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer1.model_max_length):
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # from after <BOS> to before <EOS>
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
hidden_states1 = torch.cat(states_list, dim=1)
# encoder2: merge the triple <BOS>...<EOS> <PAD>... chunks back into a single <BOS>...<EOS> <PAD>...; honestly not sure this implementation is right
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # from after <BOS> up to before the last token
if i > 0:
for j in range(len(chunk)):
if input_ids2[j, 1] == tokenizer2.eos_token_id: # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
states_list.append(chunk) # from after <BOS> to before <EOS>
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
hidden_states2 = torch.cat(states_list, dim=1)
if weight_dtype is not None:
# this is required for additional network training
hidden_states1 = hidden_states1.to(weight_dtype)
hidden_states2 = hidden_states2.to(weight_dtype)
return hidden_states1, hidden_states2, pool2
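
A worked example of the chunk re-merge above, for the typical configuration max_token_length=225 with model_max_length=77 (pure arithmetic, safe to run standalone):

max_token_length, model_max_length = 225, 77
starts = list(range(1, max_token_length, model_max_length))
print(starts)  # [1, 78, 155]: three 77-token windows
kept_per_chunk = model_max_length - 2  # drop each window's own <BOS>/<EOS>
total = 1 + len(starts) * kept_per_chunk + 1
print(total)   # 227 = shared <BOS> + 3 * 75 body tokens + shared <EOS>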
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
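
A quick numeric sanity check of the sinusoidal embedding (uses the function defined just above):

import torch

t = torch.tensor([0.0, 1.0])
emb = timestep_embedding(t, dim=8)
print(emb.shape)  # torch.Size([2, 8]): [cos(t * f) | sin(t * f)] halves
print(emb[0])     # at t=0 the cos half is all ones and the sin half all zeros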
def get_timestep_embedding(x, outdim):
assert len(x.shape) == 2
b, dims = x.shape[0], x.shape[1]
x = torch.flatten(x)
emb = timestep_embedding(x, outdim)
emb = torch.reshape(emb, (b, dims * outdim))
return emb
def get_size_embeddings(orig_size, crop_size, target_size, device):
emb1 = get_timestep_embedding(orig_size, 256)
emb2 = get_timestep_embedding(crop_size, 256)
emb3 = get_timestep_embedding(target_size, 256)
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
return vector
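
A usage sketch for the size conditioning above: SDXL feeds three (height, width) pairs, each coordinate embedded at 256 dims, so the result is a 3 * 2 * 256 = 1536-dim vector per sample:

import torch

orig_size = torch.tensor([[1024, 1024]])    # original image resolution
crop_size = torch.tensor([[0, 0]])          # top-left crop coordinates
target_size = torch.tensor([[1024, 1024]])  # target resolution
vector = get_size_embeddings(orig_size, crop_size, target_size, "cpu")
print(vector.shape)  # torch.Size([1, 1536])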
def save_sd_model_on_train_end(
args: argparse.Namespace,
src_path: str,
save_stable_diffusion_format: bool,
use_safetensors: bool,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
text_encoder1,
text_encoder2,
unet,
vae,
text_projection,
logit_scale,
ckpt_info,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sdxl_model_util.save_stable_diffusion_checkpoint(
ckpt_file,
text_encoder1,
text_encoder2,
unet,
epoch_no,
global_step,
ckpt_info,
vae,
text_projection,
logit_scale,
save_dtype,
)
def diffusers_saver(out_dir):
raise NotImplementedError("diffusers_saver is not implemented")
train_util.save_sd_model_on_train_end_common(
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
)
# saving at epoch end and at step intervals is merged into one function, since the metadata includes epoch/step and the arguments are otherwise identical
# on_epoch_end: True to save at the end of an epoch, False to save after a step interval
def save_sd_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
src_path,
save_stable_diffusion_format: bool,
use_safetensors: bool,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
text_encoder1,
text_encoder2,
unet,
vae,
text_projection,
logit_scale,
ckpt_info,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sdxl_model_util.save_stable_diffusion_checkpoint(
ckpt_file,
text_encoder1,
text_encoder2,
unet,
epoch_no,
global_step,
ckpt_info,
vae,
text_projection,
logit_scale,
save_dtype,
)
def diffusers_saver(out_dir):
raise NotImplementedError("diffusers_saver is not implemented")
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
save_stable_diffusion_format,
use_safetensors,
epoch,
num_train_epochs,
global_step,
sd_saver,
diffusers_saver,
)
# cache the Text Encoder outputs
# if weight_dtype is specified, the Text Encoders themselves, as well as their outputs, are cast to weight_dtype
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
print("caching text encoder outputs")
tokenizer1, tokenizer2 = tokenizers
text_encoder1, text_encoder2 = text_encoders
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if weight_dtype is not None:
text_encoder1.to(dtype=weight_dtype)
text_encoder2.to(dtype=weight_dtype)
text_encoder1_cache = {}
text_encoder2_cache = {}
for batch in tqdm(data_loader):
input_ids1_batch = batch["input_ids"]
input_ids2_batch = batch["input_ids2"]
# split batch to avoid OOM
# TODO specify batch size by args
for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
# remove input_ids already in cache
input_ids1 = input_ids1.squeeze(0)
input_ids2 = input_ids2.squeeze(0)
input_ids1 = [i for i in input_ids1 if tuple(i.tolist()) not in text_encoder1_cache] # compare with the same tuple keys used when storing
input_ids2 = [i for i in input_ids2 if tuple(i.tolist()) not in text_encoder2_cache]
assert len(input_ids1) == len(input_ids2)
if len(input_ids1) == 0:
continue
input_ids1 = torch.stack(input_ids1).to(accelerator.device)
input_ids2 = torch.stack(input_ids2).to(accelerator.device)
with torch.no_grad():
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
args,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu")
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu")
pool2 = pool2.to("cpu")
for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip(
input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2
):
text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
return text_encoder1_cache, text_encoder2_cache
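
A hypothetical consumer of the caches above (not part of this commit): because keys are token-id tuples, a lookup at training time can replace the encoder forward pass for captions seen during caching:

import torch

def lookup_cached_outputs(input_ids1, input_ids2, cache1, cache2):
    # keys must match the tuple(ids.tolist()) form used when storing
    hidden1 = torch.stack([cache1[tuple(ids.tolist())] for ids in input_ids1])
    pairs = [cache2[tuple(ids.tolist())] for ids in input_ids2]
    hidden2 = torch.stack([h for h, _ in pairs])
    pool2 = torch.stack([p for _, p in pairs])
    return hidden1, hidden2, pool2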
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
def verify_sdxl_training_args(args: argparse.Namespace):
assert (
not args.v2 and not args.v_parameterization
), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません"
if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
if args.multires_noise_iterations:
print(
f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
)
else:
if args.noise_offset is None:
args.noise_offset = DEFAULT_NOISE_OFFSET
elif args.noise_offset != DEFAULT_NOISE_OFFSET:
print(
f"Waring: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
)
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

View File

@@ -798,6 +798,19 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def is_text_encoder_output_cacheable(self):
return all(
[
not (
subset.caption_dropout_rate > 0
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
)
for subset in self.subsets
]
)
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # note that bucket_reso is WxH
@@ -850,7 +863,7 @@ class BaseDataset(torch.utils.data.Dataset):
continue
cache_available = self.is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
)
if cache_available: # do not add to batch
@@ -1719,6 +1732,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self) -> bool:
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
def set_current_epoch(self, epoch):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
@@ -3284,11 +3300,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
return text_encoder, vae, unet, load_stable_diffusion_format
# TODO remove this function in the future
def transform_if_model_is_DDP(text_encoder, unet, network=None):
# Transform text_encoder, unet and network from DistributedDataParallel
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
def transform_models_if_DDP(models):
# Transform text_encoder, unet and network from DistributedDataParallel
return [model.module if type(model) == DDP else model for model in models if model is not None]
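
One subtlety of transform_models_if_DDP worth noting: None entries are filtered out, so the returned list can be shorter than the input. A tiny check (uses the function above; no process group needed, since plain modules pass through unchanged):

import torch.nn as nn

models = [nn.Linear(2, 2), None, nn.Linear(2, 2)]
unwrapped = transform_models_if_DDP(models)
print(len(unwrapped))  # 2, not 3: unpack the result accordingly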
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
@@ -3430,6 +3452,42 @@ def save_sd_model_on_epoch_end_or_stepwise(
text_encoder,
unet,
vae,
):
def sd_saver(ckpt_file, epoch_no, global_step):
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
def diffusers_saver(out_dir):
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
save_stable_diffusion_format,
use_safetensors,
epoch,
num_train_epochs,
global_step,
sd_saver,
diffusers_saver,
)
def save_sd_model_on_epoch_end_or_stepwise_common(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_stable_diffusion_format: bool,
use_safetensors: bool,
epoch: int,
num_train_epochs: int,
global_step: int,
sd_saver,
diffusers_saver,
):
if on_epoch_end:
epoch_no = epoch + 1
@@ -3457,9 +3515,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
sd_saver(ckpt_file, epoch_no, global_step)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
@@ -3483,9 +3539,8 @@ def save_sd_model_on_epoch_end_or_stepwise(
out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
print(f"\nsaving model: {out_dir}")
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
diffusers_saver(out_dir)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name)
@@ -3578,6 +3633,30 @@ def save_sd_model_on_train_end(
text_encoder,
unet,
vae,
):
def sd_saver(ckpt_file, epoch_no, global_step):
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
def diffusers_saver(out_dir):
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
save_sd_model_on_train_end_common(
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
)
def save_sd_model_on_train_end_common(
args: argparse.Namespace,
save_stable_diffusion_format: bool,
use_safetensors: bool,
epoch: int,
global_step: int,
sd_saver,
diffusers_saver,
):
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
@@ -3588,9 +3667,8 @@ def save_sd_model_on_train_end(
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
)
sd_saver(ckpt_file, epoch, global_step)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
else:
@@ -3598,9 +3676,8 @@ def save_sd_model_on_train_end(
os.makedirs(out_dir, exist_ok=True)
print(f"save trained model as Diffusers to {out_dir}")
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
diffusers_saver(out_dir)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)