support textual inversion training

Kohya S
2023-07-10 22:04:02 +09:00
parent b6e328ea8f
commit f54b784d88
5 changed files with 787 additions and 446 deletions

View File

@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental
Summary of the feature:
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
  - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
  - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
  - I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
  - In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts have the following additional options:
  - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
  - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. The VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with the `--no_half_vae` option. It works with `bf16` or without mixed precision.
- The `--weighted_captions` option is not supported yet for either script.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
  - `--cache_text_encoder_outputs` is not supported.
  - `token_string` must currently consist of alphabetic characters only, due to a limitation of the open-clip tokenizer.
  - There are two options for captions:
    1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
    2. Use the `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
  - See below for the format of the embeddings.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
  - Textual Inversion is supported, but the token name used in the caption is reduced to alphabetic characters only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`, as in the sketch below.
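A minimal sketch of how an activation name is derived from a file name, assuming the conversion simply keeps alphabetic characters and appends single-letter suffixes for extra vectors (this mirrors the loader code in `sdxl_gen_img.py` further below; `activation_tokens` is a hypothetical helper):

```python
import os

def activation_tokens(embeds_file: str, num_vectors_per_token: int) -> list:
    # file name without extension, e.g. "neg_hand_v1"
    token_string = os.path.splitext(os.path.basename(embeds_file))[0]
    # keep only alphabetic characters so the open-clip tokenizer does not split the name
    token_string = "".join(c for c in token_string if c.isalpha())
    # extra vectors get single-letter suffixes: neghandv, neghandva, neghandvb, ...
    return [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]

print(activation_tokens("neg_hand_v1.safetensors", 3))  # ['neghandv', 'neghandva', 'neghandvb']
```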
`requirements.txt` is updated to support SDXL training.
@@ -54,7 +67,7 @@ Summary of the feature:
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.

Example of the optimizer settings for Adafactor with the fixed learning rate:

```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
@@ -62,13 +75,22 @@ lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
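For reference, a hedged sketch of reading such a file back and checking its shapes (one row per vector; 768 dimensions for `clip_l`, 1280 for `clip_g`); the file name here is just a placeholder:

```python
from safetensors.torch import load_file

data = load_file("my_embedding.safetensors")  # placeholder path
embs_l = data["clip_l"]  # shape: (num_vectors, 768), for text encoder 1
embs_g = data["clip_g"]  # shape: (num_vectors, 1280), for text encoder 2
assert embs_l.size(0) == embs_g.size(0), "both text encoders need the same number of vectors"
print(embs_l.shape, embs_g.shape)
```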
### TODO

- [ ] Support conversion of Diffusers SDXL models.
- [ ] Support `--weighted_captions` option.
- [ ] Change `--output_config` option to continue the training.
- [ ] Extend `--full_bf16` for all the scripts.
- [x] Support Textual Inversion training.
## About requirements.txt

View File

@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
class WrapperTokenizer:
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
# make open clip tokenizer compatible with HuggingFace tokenizer
def __init__(self):
open_clip_tokenizer = open_clip.tokenizer._tokenizer
self.model_max_length = 77
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
self.pad_token_id = 0  # 結果から推定している / inferred from the tokenizer's output
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.tokenize(*args, **kwds)
@@ -107,6 +108,42 @@ class WrapperTokenizer:
input_ids = input_ids[: eos_index + 1]  # include eos
return SimpleNamespace(**{"input_ids": input_ids})
# for Textual Inversion
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
def encode(self, text, add_special_tokens=False):
assert not add_special_tokens
input_ids = open_clip.tokenizer._tokenizer.encode(text)
return input_ids
def add_tokens(self, new_tokens):
tokens_to_add = []
for token in new_tokens:
token = token.lower()
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
tokens_to_add.append(token)
# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
for token in tokens_to_add:
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
open_clip.tokenizer._tokenizer.vocab_size += 1
# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
# this is very rough, so it will not work if the specification of open clip tokenizer changes
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
return len(tokens_to_add)
def convert_tokens_to_ids(self, tokens):
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
return input_ids
def __len__(self):
return open_clip.tokenizer._tokenizer.vocab_size
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert ( assert (
not args.weighted_captions not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

View File

@@ -320,7 +320,7 @@ class PipelineLike:
self.scheduler = scheduler
self.safety_checker = None
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
self.token_replacements_list.append({})
@@ -341,6 +341,10 @@ class PipelineLike:
token_replacements = self.token_replacements_list[tokenizer_index]
def replace_tokens(tokens):
# print("replace_tokens", tokens, "=>", token_replacements)
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
new_tokens = []
for token in tokens:
if token in token_replacements:
@@ -1594,19 +1598,26 @@ def main(args):
if "string_to_param" in data: if "string_to_param" in data:
data = data["string_to_param"] data = data["string_to_param"]
embeds1 = data["clip_l"]
embeds2 = data["clip_g"] embeds1 = data["clip_l"] # text encoder 1
embeds2 = data["clip_g"] # text encoder 2
num_vectors_per_token = embeds1.size()[0] num_vectors_per_token = embeds1.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0]
# remove non-alphabet characters to avoid splitting by tokenizer
# TODO make random alphabet string
token_string = "".join([c for c in token_string if c.isalpha()])
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
f"tokenizer already has a token for the token string (filename); characters other than alphabet are removed: {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
)
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
@@ -1617,11 +1628,11 @@ def main(args):
assert (
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
), f"token ids2 is not ordered"
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
if num_vectors_per_token > 1:
pipe.add_token_replacement(0, token_ids1[0], token_ids1)  # hoge -> hoge, hogea, hogeb, ...
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
token_ids_embeds1.append((token_ids1, embeds1))
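The `add_token_replacement` calls above register a mapping so that, at prompt-encoding time, the single trigger token id is expanded into the full id sequence of a multi-vector embedding. A minimal standalone sketch of that replacement step, assuming a plain dict as in `token_replacements_list` (all ids below are made up):

```python
import torch

def replace_tokens(tokens, token_replacements):
    # tensors are converted to plain lists before replacement, as in PipelineLike
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()
    new_tokens = []
    for token in tokens:
        if token in token_replacements:
            new_tokens.extend(token_replacements[token])  # e.g. hoge -> hoge, hogea, hogeb
        else:
            new_tokens.append(token)
    return new_tokens

print(replace_tokens(torch.tensor([320, 49408, 525]), {49408: [49408, 49409, 49410]}))
# [320, 49408, 49409, 49410, 525]
```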

View File

@@ -0,0 +1,142 @@
import argparse
import os
import regex
import torch
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
def load_target_model(self, args, weight_dtype, accelerator):
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
self.load_stable_diffusion_format = load_stable_diffusion_format
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def assert_token_string(self, token_string, tokenizers):
# tokenizer 1 seems to be ok
# count words for token string: regular expression from open_clip
pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
words = regex.findall(pat, token_string)
word_count = len(words)
assert word_count == 1, (
f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
+ f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
args,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype):
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file)
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
data = torch.load(file, map_location="cpu")
emb_l = data.get("clib_l", None) # ViT-L text encoder 1
emb_g = data.get("clib_g", None) # BiG-G text encoder 2
assert (
emb_l is not None or emb_g is not None
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
return [emb_l, emb_g]
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()
trainer.train(args)
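As a small illustration of the word-count check in `assert_token_string` above (the regular expression is the one copied from open_clip), a token string containing digits or underscores is split into several "words" and therefore rejected:

```python
import regex

pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)

for token_string in ["myembed", "neg_hand_v1"]:
    words = regex.findall(pat, token_string)
    print(token_string, words, len(words))
# myembed     -> ['myembed'], 1 word: accepted
# neg_hand_v1 -> ['neg', '_', 'hand', '_', 'v', '1'], 6 words: rejected
```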

File diff suppressed because it is too large.