mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support textual inversion training
This commit is contained in:
30
README.md
30
README.md
@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental
|
||||
Summary of the feature:
|
||||
|
||||
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
||||
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
|
||||
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
|
||||
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
|
||||
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
|
||||
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
|
||||
- In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
|
||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
||||
- Both scripts has following additional options:
|
||||
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
||||
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
|
||||
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
|
||||
- `--weighted_captions` option is not supported yet.
|
||||
|
||||
- `--weighted_captions` option is not supported yet for both scripts.
|
||||
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
|
||||
|
||||
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
|
||||
- `--cache_text_encoder_outputs` is not supported.
|
||||
- `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
|
||||
- There are two options for captions:
|
||||
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
|
||||
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
|
||||
- See below for the format of the embeddings.
|
||||
|
||||
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
|
||||
- Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
|
||||
|
||||
`requirements.txt` is updated to support SDXL training.
|
||||
|
||||
@@ -54,7 +67,7 @@ Summary of the feature:
|
||||
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
|
||||
|
||||
Example of the optimizer settings for Adafactor with the fixed learning rate:
|
||||
```
|
||||
```toml
|
||||
optimizer_type = "adafactor"
|
||||
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
||||
lr_scheduler = "constant_with_warmup"
|
||||
@@ -62,13 +75,22 @@ lr_warmup_steps = 100
|
||||
learning_rate = 4e-7 # SDXL original learning rate
|
||||
```
|
||||
|
||||
### Format of Textual Inversion embeddings
|
||||
|
||||
```python
|
||||
from safetensors.torch import save_file
|
||||
|
||||
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
|
||||
save_file(state_dict, file)
|
||||
```
|
||||
|
||||
### TODO
|
||||
|
||||
- [ ] Support Textual Inversion training.
|
||||
- [ ] Support conversion of Diffusers SDXL models.
|
||||
- [ ] Support `--weighted_captions` option.
|
||||
- [ ] Change `--output_config` option to continue the training.
|
||||
- [ ] Extend `--full_bf16` for all the scripts.
|
||||
- [x] Support Textual Inversion training.
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
|
||||
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
|
||||
class WrapperTokenizer:
|
||||
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
|
||||
# make open clip tokenizer compatible with HuggingFace tokenizer
|
||||
def __init__(self):
|
||||
open_clip_tokenizer = open_clip.tokenizer._tokenizer
|
||||
self.model_max_length = 77
|
||||
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
|
||||
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
|
||||
self.pad_token_id = 0 # 結果から推定している
|
||||
self.pad_token_id = 0 # 結果から推定している assumption from result
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
return self.tokenize(*args, **kwds)
|
||||
@@ -107,6 +108,42 @@ class WrapperTokenizer:
|
||||
input_ids = input_ids[: eos_index + 1] # include eos
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
# for Textual Inversion
|
||||
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
|
||||
|
||||
def encode(self, text, add_special_tokens=False):
|
||||
assert not add_special_tokens
|
||||
input_ids = open_clip.tokenizer._tokenizer.encode(text)
|
||||
return input_ids
|
||||
|
||||
def add_tokens(self, new_tokens):
|
||||
tokens_to_add = []
|
||||
for token in new_tokens:
|
||||
token = token.lower()
|
||||
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
|
||||
tokens_to_add.append(token)
|
||||
|
||||
# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
|
||||
for token in tokens_to_add:
|
||||
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
|
||||
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
|
||||
open_clip.tokenizer._tokenizer.vocab_size += 1
|
||||
|
||||
# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
|
||||
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
|
||||
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
|
||||
# this is very rough, so it will not work if the specification of open clip tokenizer changes
|
||||
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
|
||||
|
||||
return len(tokens_to_add)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
|
||||
return input_ids
|
||||
|
||||
def __len__(self):
|
||||
return open_clip.tokenizer._tokenizer.vocab_size
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
|
||||
|
||||
@@ -320,7 +320,7 @@ class PipelineLike:
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
# Textual Inversion # not tested yet
|
||||
# Textual Inversion
|
||||
self.token_replacements_list = []
|
||||
for _ in range(len(self.text_encoders)):
|
||||
self.token_replacements_list.append({})
|
||||
@@ -341,6 +341,10 @@ class PipelineLike:
|
||||
token_replacements = self.token_replacements_list[tokenizer_index]
|
||||
|
||||
def replace_tokens(tokens):
|
||||
# print("replace_tokens", tokens, "=>", token_replacements)
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens = tokens.tolist()
|
||||
|
||||
new_tokens = []
|
||||
for token in tokens:
|
||||
if token in token_replacements:
|
||||
@@ -1594,19 +1598,26 @@ def main(args):
|
||||
|
||||
if "string_to_param" in data:
|
||||
data = data["string_to_param"]
|
||||
embeds1 = data["clip_l"]
|
||||
embeds2 = data["clip_g"]
|
||||
|
||||
embeds1 = data["clip_l"] # text encoder 1
|
||||
embeds2 = data["clip_g"] # text encoder 2
|
||||
|
||||
num_vectors_per_token = embeds1.size()[0]
|
||||
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||
|
||||
# remove non-alphabet characters to avoid splitting by tokenizer
|
||||
# TODO make random alphabet string
|
||||
token_string = "".join([c for c in token_string if c.isalpha()])
|
||||
|
||||
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
|
||||
num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now
|
||||
assert (
|
||||
num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
|
||||
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
|
||||
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
|
||||
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
|
||||
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
|
||||
)
|
||||
|
||||
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
|
||||
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
|
||||
@@ -1617,11 +1628,11 @@ def main(args):
|
||||
assert (
|
||||
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
|
||||
), f"token ids2 is not ordered"
|
||||
assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
|
||||
assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
|
||||
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
|
||||
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
|
||||
|
||||
if num_vectors_per_token > 1:
|
||||
pipe.add_token_replacement(0, token_ids1[0], token_ids1)
|
||||
pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ...
|
||||
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
|
||||
|
||||
token_ids_embeds1.append((token_ids1, embeds1))
|
||||
|
||||
142
sdxl_train_textual_inversion.py
Normal file
142
sdxl_train_textual_inversion.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import regex
|
||||
import torch
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
import train_textual_inversion
|
||||
|
||||
|
||||
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
||||
|
||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
return tokenizer
|
||||
|
||||
def assert_token_string(self, token_string, tokenizers):
|
||||
# tokenizer 1 is seems to be ok
|
||||
|
||||
# count words for token string: regular expression from open_clip
|
||||
pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
|
||||
words = regex.findall(pat, token_string)
|
||||
word_count = len(words)
|
||||
assert word_count == 1, (
|
||||
f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
|
||||
+ f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
|
||||
)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.enable_grad():
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
||||
args,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
tokenizers[1],
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
crop_size = batch["crop_top_lefts"]
|
||||
target_size = batch["target_sizes_hw"]
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# concat embeddings
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype):
|
||||
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
|
||||
|
||||
if save_dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, file)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
data = load_file(file)
|
||||
else:
|
||||
data = torch.load(file, map_location="cpu")
|
||||
|
||||
emb_l = data.get("clib_l", None) # ViT-L text encoder 1
|
||||
emb_g = data.get("clib_g", None) # BiG-G text encoder 2
|
||||
|
||||
assert (
|
||||
emb_l is not None or emb_g is not None
|
||||
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
|
||||
|
||||
return [emb_l, emb_g]
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_textual_inversion.setup_parser()
|
||||
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
|
||||
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
trainer = SdxlTextualInversionTrainer()
|
||||
trainer.train(args)
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user