mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support textual inversion training
This commit is contained in:
30
README.md
30
README.md
@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental
|
||||
Summary of the feature:
|
||||
|
||||
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
||||
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
|
||||
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
|
||||
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
|
||||
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
|
||||
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
|
||||
- In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
|
||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
||||
- Both scripts have the following additional options:
|
||||
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
||||
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
|
||||
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
|
||||
- `--weighted_captions` option is not supported yet.
|
||||
|
||||
- `--weighted_captions` option is not supported yet for both scripts.
|
||||
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
|
||||
|
||||
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
|
||||
- `--cache_text_encoder_outputs` is not supported.
|
||||
- `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
|
||||
- There are two options for captions:
|
||||
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
|
||||
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
|
||||
- See below for the format of the embeddings.
|
||||
|
||||
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
|
||||
- Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
|
||||
|
||||
`requirements.txt` is updated to support SDXL training.
|
||||
|
||||
@@ -54,7 +67,7 @@ Summary of the feature:
|
||||
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
|
||||
|
||||
Example of the optimizer settings for Adafactor with the fixed learning rate:
|
||||
```
|
||||
```toml
|
||||
optimizer_type = "adafactor"
|
||||
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
||||
lr_scheduler = "constant_with_warmup"
|
||||
@@ -62,13 +75,22 @@ lr_warmup_steps = 100
|
||||
learning_rate = 4e-7 # SDXL original learning rate
|
||||
```
|
||||
|
||||
### Format of Textual Inversion embeddings
|
||||
|
||||
```python
|
||||
from safetensors.torch import save_file
|
||||
|
||||
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
|
||||
save_file(state_dict, file)
|
||||
```
|
||||
|
||||
### TODO
|
||||
|
||||
- [ ] Support Textual Inversion training.
|
||||
- [ ] Support conversion of Diffusers SDXL models.
|
||||
- [ ] Support `--weighted_captions` option.
|
||||
- [ ] Change `--output_config` option to continue the training.
|
||||
- [ ] Extend `--full_bf16` for all the scripts.
|
||||
- [x] Support Textual Inversion training.
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
|
||||
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
|
||||
class WrapperTokenizer:
|
||||
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
|
||||
# make open clip tokenizer compatible with HuggingFace tokenizer
|
||||
def __init__(self):
|
||||
open_clip_tokenizer = open_clip.tokenizer._tokenizer
|
||||
self.model_max_length = 77
|
||||
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
|
||||
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
|
||||
self.pad_token_id = 0 # 結果から推定している
|
||||
self.pad_token_id = 0 # 結果から推定している assumption from result
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
return self.tokenize(*args, **kwds)
|
||||
@@ -107,6 +108,42 @@ class WrapperTokenizer:
|
||||
input_ids = input_ids[: eos_index + 1] # include eos
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
# for Textual Inversion
|
||||
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
|
||||
|
||||
def encode(self, text, add_special_tokens=False):
|
||||
assert not add_special_tokens
|
||||
input_ids = open_clip.tokenizer._tokenizer.encode(text)
|
||||
return input_ids
|
||||
|
||||
def add_tokens(self, new_tokens):
|
||||
tokens_to_add = []
|
||||
for token in new_tokens:
|
||||
token = token.lower()
|
||||
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
|
||||
tokens_to_add.append(token)
|
||||
|
||||
# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
|
||||
for token in tokens_to_add:
|
||||
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
|
||||
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
|
||||
open_clip.tokenizer._tokenizer.vocab_size += 1
|
||||
|
||||
# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
|
||||
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
|
||||
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
|
||||
# this is very rough, so it will not work if the specification of open clip tokenizer changes
|
||||
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
|
||||
|
||||
return len(tokens_to_add)
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
|
||||
return input_ids
|
||||
|
||||
def __len__(self):
|
||||
return open_clip.tokenizer._tokenizer.vocab_size
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
|
||||
|
||||
@@ -320,7 +320,7 @@ class PipelineLike:
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
# Textual Inversion # not tested yet
|
||||
# Textual Inversion
|
||||
self.token_replacements_list = []
|
||||
for _ in range(len(self.text_encoders)):
|
||||
self.token_replacements_list.append({})
|
||||
@@ -341,6 +341,10 @@ class PipelineLike:
|
||||
token_replacements = self.token_replacements_list[tokenizer_index]
|
||||
|
||||
def replace_tokens(tokens):
|
||||
# print("replace_tokens", tokens, "=>", token_replacements)
|
||||
if isinstance(tokens, torch.Tensor):
|
||||
tokens = tokens.tolist()
|
||||
|
||||
new_tokens = []
|
||||
for token in tokens:
|
||||
if token in token_replacements:
|
||||
@@ -1594,19 +1598,26 @@ def main(args):
|
||||
|
||||
if "string_to_param" in data:
|
||||
data = data["string_to_param"]
|
||||
embeds1 = data["clip_l"]
|
||||
embeds2 = data["clip_g"]
|
||||
|
||||
embeds1 = data["clip_l"] # text encoder 1
|
||||
embeds2 = data["clip_g"] # text encoder 2
|
||||
|
||||
num_vectors_per_token = embeds1.size()[0]
|
||||
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||
|
||||
# remove non-alphabet characters to avoid splitting by tokenizer
|
||||
# TODO make random alphabet string
|
||||
token_string = "".join([c for c in token_string if c.isalpha()])
|
||||
|
||||
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
|
||||
num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now
|
||||
assert (
|
||||
num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
|
||||
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
|
||||
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
|
||||
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
|
||||
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
|
||||
)
|
||||
|
||||
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
|
||||
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
|
||||
@@ -1617,11 +1628,11 @@ def main(args):
|
||||
assert (
|
||||
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
|
||||
), f"token ids2 is not ordered"
|
||||
assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
|
||||
assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
|
||||
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
|
||||
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
|
||||
|
||||
if num_vectors_per_token > 1:
|
||||
pipe.add_token_replacement(0, token_ids1[0], token_ids1)
|
||||
pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ...
|
||||
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
|
||||
|
||||
token_ids_embeds1.append((token_ids1, embeds1))
|
||||
|
||||
142
sdxl_train_textual_inversion.py
Normal file
142
sdxl_train_textual_inversion.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import regex
|
||||
import torch
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
import train_textual_inversion
|
||||
|
||||
|
||||
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
    """Textual Inversion trainer specialized for SDXL.

    Overrides the SD1/SD2 trainer hooks to handle SDXL's two text encoders
    (CLIP ViT-L and OpenCLIP BiG-G), its VAE scale factor, and the size
    conditioning embeddings the SDXL U-Net expects.
    """

    def __init__(self):
        super().__init__()
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR

    def assert_extra_args(self, args, train_dataset_group):
        super().assert_extra_args(args, train_dataset_group)
        sdxl_train_util.verify_sdxl_training_args(args)

    def load_target_model(self, args, weight_dtype, accelerator):
        """Load the SDXL model.

        Returns (model_version, [text_encoder1, text_encoder2], vae, unet).
        """
        (
            load_stable_diffusion_format,
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)

        # Keep checkpoint metadata around so it can be written back on save.
        self.load_stable_diffusion_format = load_stable_diffusion_format
        self.logit_scale = logit_scale
        self.ckpt_info = ckpt_info

        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet

    def load_tokenizer(self, args):
        tokenizer = sdxl_train_util.load_tokenizers(args)
        return tokenizer

    def assert_token_string(self, token_string, tokenizers):
        # tokenizer 1 (CLIP ViT-L) seems to be OK; the open-clip tokenizer is the
        # restrictive one, so validate the token string against its word splitting.

        # count words for token string: regular expression from open_clip
        pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
        words = regex.findall(pat, token_string)
        word_count = len(words)
        assert word_count == 1, (
            f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
            + f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
        )

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
        """Encode the batch's captions with both text encoders.

        Gradients must be enabled so they flow into the trainable embeddings.
        """
        input_ids1 = batch["input_ids"]
        input_ids2 = batch["input_ids2"]
        with torch.enable_grad():
            input_ids1 = input_ids1.to(accelerator.device)
            input_ids2 = input_ids2.to(accelerator.device)
            encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
                args,
                input_ids1,
                input_ids2,
                tokenizers[0],
                tokenizers[1],
                text_encoders[0],
                text_encoders[1],
                None if not args.full_fp16 else weight_dtype,
            )
        return encoder_hidden_states1, encoder_hidden_states2, pool2

    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
        noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

        # get size embeddings
        orig_size = batch["original_sizes_hw"]
        crop_size = batch["crop_top_lefts"]
        target_size = batch["target_sizes_hw"]
        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

        # concat embeddings
        encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
        vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
        text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)

        noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
        return noise_pred

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
        sdxl_train_util.sample_images(
            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
        )

    def save_weights(self, file, updated_embs, save_dtype):
        """Save both encoders' embeddings under the keys "clip_l" and "clip_g"."""
        state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}

        if save_dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(save_dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            save_file(state_dict, file)
        else:
            torch.save(state_dict, file)

    def load_weights(self, file):
        """Load embeddings saved by save_weights; returns [emb_l, emb_g] (either may be None)."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            data = load_file(file)
        else:
            data = torch.load(file, map_location="cpu")

        # BUGFIX: the keys were misspelled "clib_l"/"clib_g", which never match the
        # "clip_l"/"clip_g" keys written by save_weights, so loading always failed.
        emb_l = data.get("clip_l", None)  # ViT-L text encoder 1
        emb_g = data.get("clip_g", None)  # BiG-G text encoder 2

        assert (
            emb_l is not None or emb_g is not None
        ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"

        return [emb_l, emb_g]
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
    """Build the argument parser for SDXL Textual Inversion training."""
    # sdxl_train_util.add_sdxl_training_arguments(parser) is deliberately NOT
    # called: it would only add text encoder output caching, which this
    # training script does not support.
    return train_textual_inversion.setup_parser()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    # Allow a config file to override/extend command-line arguments.
    args = train_util.read_config_from_file(args, parser)

    SdxlTextualInversionTrainer().train(args)
|
||||
@@ -8,6 +8,7 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import model_util
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
@@ -20,8 +21,6 @@ import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
@@ -78,7 +77,81 @@ imagenet_style_templates_small = [
|
||||
]
|
||||
|
||||
|
||||
def train(args):
|
||||
class TextualInversionTrainer:
|
||||
def __init__(self):
    # Scale factor applied when converting images to/from VAE latents
    # (SD v1/v2 value; the SDXL subclass overrides this).
    self.vae_scale_factor = 0.18215
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
    """Hook for subclasses to validate trainer-specific arguments; no-op in the base trainer."""
    pass
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
    """Load the SD1/SD2 model; returns (model_version, text_encoder, vae, unet)."""
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
    model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
    return model_version, text_encoder, vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
    """Return the tokenizer (or list of tokenizers) for this model family."""
    return train_util.load_tokenizer(args)
|
||||
|
||||
def assert_token_string(self, token_string, tokenizers):
    """Validate the token string for the given tokenizers; the base trainer accepts anything."""
    return None
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
    """Return the encoder hidden states for the batch's captions.

    Gradients are explicitly enabled so they can flow into the trainable
    token embeddings of the text encoder.
    """
    with torch.enable_grad():
        ids = batch["input_ids"].to(accelerator.device)
        hidden_states = train_util.get_hidden_states(args, ids, tokenizers[0], text_encoders[0], None)
    return hidden_states
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    """Run the U-Net on the noisy latents and return the predicted noise."""
    return unet(noisy_latents, timesteps, text_conds).sample
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
    """Generate sample images during training; delegates to train_util.sample_images."""
    train_util.sample_images(
        accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
    )
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype):
    """Write the learned embedding to `file`.

    Saves as safetensors when the extension is ".safetensors", otherwise in
    torch format under the key "emb_params". `save_dtype=None` keeps the
    tensors' current dtype.
    """
    state_dict = {"emb_params": updated_embs[0]}

    if save_dtype is not None:
        # Detach, move to CPU, and cast to the requested precision before saving.
        state_dict = {name: t.detach().clone().to("cpu").to(save_dtype) for name, t in state_dict.items()}

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI
|
||||
|
||||
def load_weights(self, file):
    """Load an embedding file; returns a one-element list with a (n_vectors, dim) tensor.

    Accepts safetensors files and torch files, including the Web UI's
    "string_to_param" textual inversion format.
    """
    ext = os.path.splitext(file)[1]
    if ext == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible to Web UI's file format
        data = torch.load(file, map_location="cpu")
        if type(data) != dict:
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
            if hasattr(data, "_parameters"):  # support old PyTorch?
                data = getattr(data, "_parameters")

    emb = next(iter(data.values()))
    if type(emb) != torch.Tensor:
        raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")

    # Normalize a single vector to a (1, dim) matrix.
    if len(emb.size()) == 1:
        emb = emb.unsqueeze(0)

    return [emb]
|
||||
|
||||
def train(self, args):
|
||||
if args.output_name is None:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
@@ -91,7 +164,8 @@ def train(args):
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
|
||||
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
@@ -101,29 +175,60 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
|
||||
|
||||
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
|
||||
accelerator.print(
|
||||
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
|
||||
+ "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
|
||||
)
|
||||
|
||||
# Convert the init_word to token_id
|
||||
init_token_ids_list = []
|
||||
if args.init_word is not None:
|
||||
for i, tokenizer in enumerate(tokenizers):
|
||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||
accelerator.print(
|
||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / "
|
||||
+ f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
|
||||
)
|
||||
init_token_ids_list.append(init_token_ids)
|
||||
else:
|
||||
init_token_ids = None
|
||||
init_token_ids_list = [None] * len(tokenizers)
|
||||
|
||||
# tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
|
||||
# token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
|
||||
# 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした
|
||||
|
||||
# if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
|
||||
# originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
|
||||
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
|
||||
token_strings = [args.token_string] + [
|
||||
f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
|
||||
]
|
||||
token_ids_list = []
|
||||
token_embeds_list = []
|
||||
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
||||
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||
assert (
|
||||
num_added_tokens == args.num_vectors_per_token
|
||||
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}"
|
||||
|
||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||
accelerator.print(f"tokens are added: {token_ids}")
|
||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}")
|
||||
assert (
|
||||
min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
|
||||
), f"token ids is not ordered : tokenizer {i+1}, {token_ids}"
|
||||
assert (
|
||||
len(tokenizer) - 1 == token_ids[-1]
|
||||
), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
|
||||
token_ids_list.append(token_ids)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
@@ -134,14 +239,16 @@ def train(args):
|
||||
for i, token_id in enumerate(token_ids):
|
||||
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
||||
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
token_embeds_list.append(token_embeds)
|
||||
|
||||
# load weights
|
||||
if args.weights is not None:
|
||||
embeddings = load_weights(args.weights)
|
||||
embeddings_list = self.load_weights(args.weights)
|
||||
assert len(token_ids) == len(
|
||||
embeddings
|
||||
embeddings_list[0]
|
||||
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||
# accelerator.print(token_ids, embeddings.size())
|
||||
for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list):
|
||||
for token_id, embedding in zip(token_ids, embeddings):
|
||||
token_embeds[token_id] = embedding
|
||||
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
@@ -190,10 +297,12 @@ def train(args):
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -210,11 +319,13 @@ def train(args):
|
||||
captions.append(tmpl.format(replace_to))
|
||||
train_dataset_group.add_replacement("", captions)
|
||||
|
||||
# サンプル生成用
|
||||
if args.num_vectors_per_token > 1:
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
else:
|
||||
prompt_replacement = None
|
||||
else:
|
||||
# サンプル生成用
|
||||
if args.num_vectors_per_token > 1:
|
||||
replace_to = " ".join(token_strings)
|
||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||
@@ -236,6 +347,7 @@ def train(args):
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -253,11 +365,14 @@ def train(args):
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||
trainable_params = []
|
||||
for text_encoder in text_encoders:
|
||||
trainable_params += text_encoder.get_input_embeddings().parameters()
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
@@ -277,7 +392,9 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -286,16 +403,34 @@ def train(args):
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
if len(text_encoders) == 1:
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
|
||||
|
||||
elif len(text_encoders) == 2:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
# transform DDP after prepare
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
|
||||
|
||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
index_no_updates_list = []
|
||||
orig_embeds_params_list = []
|
||||
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||
index_no_updates_list.append(index_no_updates)
|
||||
|
||||
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
orig_embeds_params_list.append(orig_embeds_params)
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.requires_grad_(True)
|
||||
@@ -307,6 +442,7 @@ def train(args):
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
# TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
|
||||
unet.train()
|
||||
else:
|
||||
unet.eval()
|
||||
@@ -319,6 +455,10 @@ def train(args):
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(weight_dtype)
|
||||
if args.full_bf16:
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# resumeする
|
||||
@@ -356,12 +496,12 @@ def train(args):
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
|
||||
def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, embs, save_dtype)
|
||||
self.save_weights(ckpt_file, embs_list, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
@@ -376,34 +516,36 @@ def train(args):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.train()
|
||||
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoder):
|
||||
with accelerator.accumulate(text_encoders[0]):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
latents = latents * self.vae_scale_factor
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
# use float instead of fp16/bf16 because text encoder is float
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
|
||||
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
@@ -435,29 +577,48 @@ def train(args):
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
||||
text_encoders, orig_embeds_params_list, index_no_updates_list
|
||||
):
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
]
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs_list = []
|
||||
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
|
||||
updated_embs = (
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[token_ids]
|
||||
.data.detach()
|
||||
.clone()
|
||||
)
|
||||
updated_embs_list.append(updated_embs)
|
||||
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||
save_model(ckpt_name, updated_embs_list, global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
@@ -492,13 +653,16 @@ def train(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs_list = []
|
||||
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
|
||||
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
updated_embs_list.append(updated_embs)
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if accelerator.is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
|
||||
save_model(ckpt_name, updated_embs_list, epoch + 1, global_step)
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
@@ -508,8 +672,17 @@ def train(args):
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# end of epoch
|
||||
@@ -525,58 +698,13 @@ def train(args):
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
||||
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def save_weights(file, updated_embs, save_dtype):
    """Write the learned embedding(s) to *file*.

    The weights are stored under the key ``emb_params``.  If *save_dtype*
    is given, tensors are moved to CPU and cast before saving.  A
    ``.safetensors`` extension selects the safetensors format; anything
    else is written with ``torch.save`` (readable by the Web UI).
    """
    state_dict = {"emb_params": updated_embs}

    if save_dtype is not None:
        # Detach + clone so the on-disk copy is independent of the live model.
        state_dict = {name: tensor.detach().clone().to("cpu").to(save_dtype) for name, tensor in state_dict.items()}

    _, ext = os.path.splitext(file)
    if ext == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        # plain torch format, loadable by the Web UI
        torch.save(state_dict, file)
|
||||
|
||||
|
||||
def load_weights(file):
    """Load a textual-inversion embedding from *file* and return it as a 2-D tensor.

    Supports this repo's own format (``{"emb_params": tensor}``) as well as the
    Web UI's ``string_to_param`` format.  A 1-D embedding is unsqueezed to
    shape ``(1, dim)``.

    Raises:
        ValueError: if the file is not a dict or the stored value is not a tensor.
    """
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible to Web UI's file format
        data = torch.load(file, map_location="cpu")
        if not isinstance(data, dict):
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
            if hasattr(data, "_parameters"):  # support old PyTorch?
                data = getattr(data, "_parameters")

    emb = next(iter(data.values()))
    # isinstance (not exact type match) so torch.nn.Parameter — what the Web UI's
    # _parameters dict actually contains — is accepted as well.
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")

    if len(emb.size()) == 1:
        emb = emb.unsqueeze(0)

    return emb
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -626,4 +754,5 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
trainer = TextualInversionTrainer()
|
||||
trainer.train(args)
|
||||
|
||||
Reference in New Issue
Block a user