support textual inversion training

This commit is contained in:
Kohya S
2023-07-10 22:04:02 +09:00
parent b6e328ea8f
commit f54b784d88
5 changed files with 787 additions and 446 deletions

View File

@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental
Summary of the feature:
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
- In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
- `--weighted_captions` option is not supported yet.
- `--weighted_captions` option is not supported yet for both scripts.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
- Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
`requirements.txt` is updated to support SDXL training.
@@ -54,7 +67,7 @@ Summary of the feature:
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
@@ -62,13 +75,22 @@ lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### TODO
- [ ] Support Textual Inversion training.
- [ ] Support conversion of Diffusers SDXL models.
- [ ] Support `--weighted_captions` option.
- [ ] Change `--output_config` option to continue the training.
- [ ] Extend `--full_bf16` for all the scripts.
- [x] Support Textual Inversion training.
## About requirements.txt

View File

@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
class WrapperTokenizer:
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
# make open clip tokenizer compatible with HuggingFace tokenizer
def __init__(self):
open_clip_tokenizer = open_clip.tokenizer._tokenizer
self.model_max_length = 77
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
self.pad_token_id = 0 # 結果から推定している
self.pad_token_id = 0 # 結果から推定している assumption from result
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.tokenize(*args, **kwds)
@@ -107,6 +108,42 @@ class WrapperTokenizer:
input_ids = input_ids[: eos_index + 1] # include eos
return SimpleNamespace(**{"input_ids": input_ids})
# for Textual Inversion
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
def encode(self, text, add_special_tokens=False):
assert not add_special_tokens
input_ids = open_clip.tokenizer._tokenizer.encode(text)
return input_ids
def add_tokens(self, new_tokens):
tokens_to_add = []
for token in new_tokens:
token = token.lower()
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
tokens_to_add.append(token)
# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
for token in tokens_to_add:
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
open_clip.tokenizer._tokenizer.vocab_size += 1
# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
# this is very rough, so it will not work if the specification of open clip tokenizer changes
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
return len(tokens_to_add)
def convert_tokens_to_ids(self, tokens):
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
return input_ids
def __len__(self):
return open_clip.tokenizer._tokenizer.vocab_size
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not args.weighted_captions
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

View File

@@ -320,7 +320,7 @@ class PipelineLike:
self.scheduler = scheduler
self.safety_checker = None
# Textual Inversion # not tested yet
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
self.token_replacements_list.append({})
@@ -341,6 +341,10 @@ class PipelineLike:
token_replacements = self.token_replacements_list[tokenizer_index]
def replace_tokens(tokens):
# print("replace_tokens", tokens, "=>", token_replacements)
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
new_tokens = []
for token in tokens:
if token in token_replacements:
@@ -1594,19 +1598,26 @@ def main(args):
if "string_to_param" in data:
data = data["string_to_param"]
embeds1 = data["clip_l"]
embeds2 = data["clip_g"]
embeds1 = data["clip_l"] # text encoder 1
embeds2 = data["clip_g"] # text encoder 2
num_vectors_per_token = embeds1.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# remove non-alphabet characters to avoid splitting by tokenizer
# TODO make random alphabet string
token_string = "".join([c for c in token_string if c.isalpha()])
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now
assert (
num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
)
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
@@ -1617,11 +1628,11 @@ def main(args):
assert (
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
), f"token ids2 is not ordered"
assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
if num_vectors_per_token > 1:
pipe.add_token_replacement(0, token_ids1[0], token_ids1)
pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ...
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
token_ids_embeds1.append((token_ids1, embeds1))

View File

@@ -0,0 +1,142 @@
import argparse
import os
import regex
import torch
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
def load_target_model(self, args, weight_dtype, accelerator):
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
self.load_stable_diffusion_format = load_stable_diffusion_format
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers):
# tokenizer 1 is seems to be ok
# count words for token string: regular expression from open_clip
pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
words = regex.findall(pat, token_string)
word_count = len(words)
assert word_count == 1, (
f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
+ f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
args,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype):
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file)
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
data = torch.load(file, map_location="cpu")
emb_l = data.get("clib_l", None) # ViT-L text encoder 1
emb_g = data.get("clib_g", None) # BiG-G text encoder 2
assert (
emb_l is not None or emb_g is not None
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
return [emb_l, emb_g]
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()
trainer.train(args)

View File

@@ -8,6 +8,7 @@ from tqdm import tqdm
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -20,8 +21,6 @@ import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
@@ -78,7 +77,81 @@ imagenet_style_templates_small = [
]
def train(args):
class TextualInversionTrainer:
def __init__(self):
self.vae_scale_factor = 0.18215
def assert_extra_args(self, args, train_dataset_group):
pass
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
def load_tokenizer(self, args):
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers):
pass
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
with torch.enable_grad():
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
return encoder_hidden_states
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype):
state_dict = {"emb_params": updated_embs[0]}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
# compatible to Web UI's file format
data = torch.load(file, map_location="cpu")
if type(data) != dict:
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
if "string_to_param" in data: # textual inversion embeddings
data = data["string_to_param"]
if hasattr(data, "_parameters"): # support old PyTorch?
data = getattr(data, "_parameters")
emb = next(iter(data.values()))
if type(emb) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
if len(emb.size()) == 1:
emb = emb.unsqueeze(0)
return [emb]
def train(self, args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
@@ -91,7 +164,8 @@ def train(args):
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
# acceleratorを準備する
print("prepare accelerator")
@@ -101,29 +175,60 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
accelerator.print(
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
+ "accelerateでは複数のモデルテキストエンコーダーのgradient_accumulation_stepsはサポートされていないようです"
)
# Convert the init_word to token_id
init_token_ids_list = []
if args.init_word is not None:
for i, tokenizer in enumerate(tokenizers):
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
accelerator.print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / "
+ f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
)
init_token_ids_list.append(init_token_ids)
else:
init_token_ids = None
init_token_ids_list = [None] * len(tokenizers)
# tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
# token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
# 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした
# if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
# originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [
f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == args.num_vectors_per_token
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
accelerator.print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}")
assert (
min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
), f"token ids is not ordered : tokenizer {i+1}, {token_ids}"
assert (
len(tokenizer) - 1 == token_ids[-1]
), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
token_ids_list.append(token_ids)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -134,14 +239,16 @@ def train(args):
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
token_embeds_list.append(token_embeds)
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
embeddings_list = self.load_weights(args.weights)
assert len(token_ids) == len(
embeddings
embeddings_list[0]
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# accelerator.print(token_ids, embeddings.size())
for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list):
for token_id, embedding in zip(token_ids, embeddings):
token_embeds[token_id] = embedding
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
@@ -190,10 +297,12 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
self.assert_extra_args(args, train_dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -210,11 +319,13 @@ def train(args):
captions.append(tmpl.format(replace_to))
train_dataset_group.add_replacement("", captions)
# サンプル生成用
if args.num_vectors_per_token > 1:
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
else:
# サンプル生成用
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
@@ -236,6 +347,7 @@ def train(args):
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
vae.set_use_memory_efficient_attention_xformers(args.xformers)
# 学習を準備する
if cache_latents:
@@ -253,11 +365,14 @@ def train(args):
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
for text_encoder in text_encoders:
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
trainable_params = []
for text_encoder in text_encoders:
trainable_params += text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
@@ -277,7 +392,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -286,16 +403,34 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
if len(text_encoders) == 1:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
else:
raise NotImplementedError()
index_no_updates_list = []
orig_embeds_params_list = []
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
index_no_updates_list.append(index_no_updates)
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
orig_embeds_params_list.append(orig_embeds_params)
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
@@ -307,6 +442,7 @@ def train(args):
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
# TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
unet.train()
else:
unet.eval()
@@ -319,6 +455,10 @@ def train(args):
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
for text_encoder in text_encoders:
text_encoder.to(weight_dtype)
if args.full_bf16:
for text_encoder in text_encoders:
text_encoder.to(weight_dtype)
# resumeする
@@ -356,12 +496,12 @@ def train(args):
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
self.save_weights(ckpt_file, embs_list, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -376,34 +516,36 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for text_encoder in text_encoders:
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
@@ -435,29 +577,48 @@ def train(args):
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
]
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
self.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
updated_embs_list = []
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
updated_embs = (
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[token_ids]
.data.detach()
.clone()
)
updated_embs_list.append(updated_embs)
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, updated_embs, global_step, epoch)
save_model(ckpt_name, updated_embs_list, global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -492,13 +653,16 @@ def train(args):
accelerator.wait_for_everyone()
updated_embs_list = []
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
updated_embs_list.append(updated_embs)
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if accelerator.is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
save_model(ckpt_name, updated_embs_list, epoch + 1, global_step)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -508,8 +672,17 @@ def train(args):
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
self.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
# end of epoch
@@ -525,58 +698,13 @@ def train(args):
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
def save_weights(file, updated_embs, save_dtype):
state_dict = {"emb_params": updated_embs}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
# compatible to Web UI's file format
data = torch.load(file, map_location="cpu")
if type(data) != dict:
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
if "string_to_param" in data: # textual inversion embeddings
data = data["string_to_param"]
if hasattr(data, "_parameters"): # support old PyTorch?
data = getattr(data, "_parameters")
emb = next(iter(data.values()))
if type(emb) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
if len(emb.size()) == 1:
emb = emb.unsqueeze(0)
return emb
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
@@ -626,4 +754,5 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
trainer = TextualInversionTrainer()
trainer.train(args)