Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)

Commit: support textual inversion training

README.md (30 lines changed)
@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental
 Summary of the feature:
 
 - `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
-  - `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
+  - `--full_bf16` option is added. Thanks to KohakuBlueleaf!
+    - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
   - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
   - I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
+  - In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
 - `prepare_buckets_latents.py` now supports SDXL fine-tuning.
 - `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
 - Both scripts has following additional options:
   - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
   - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
 - The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
-- `--weighted_captions` option is not supported yet.
+- `--weighted_captions` option is not supported yet for both scripts.
 - `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
+- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
+  - `--cache_text_encoder_outputs` is not supported.
+  - `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
+  - There are two options for captions:
+    1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
+    2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
+  - See below for the format of the embeddings.
 - `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
+  - Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
 
 `requirements.txt` is updated to support SDXL training.
@@ -54,7 +67,7 @@ Summary of the feature:
 - `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
 
 Example of the optimizer settings for Adafactor with the fixed learning rate:
-```
+```toml
 optimizer_type = "adafactor"
 optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
 lr_scheduler = "constant_with_warmup"
@@ -62,13 +75,22 @@ lr_warmup_steps = 100
 learning_rate = 4e-7 # SDXL original learning rate
 ```
 
+### Format of Textual Inversion embeddings
+
+```python
+from safetensors.torch import save_file
+
+state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
+save_file(state_dict, file)
+```
+
 ### TODO
 
-- [ ] Support Textual Inversion training.
 - [ ] Support conversion of Diffusers SDXL models.
 - [ ] Support `--weighted_captions` option.
 - [ ] Change `--output_config` option to continue the training.
 - [ ] Extend `--full_bf16` for all the scripts.
+- [x] Support Textual Inversion training.
 
 ## About requirements.txt
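The two-key embedding layout documented above can be sketched without torch. Plain nested lists stand in for the real tensors here, and `make_sdxl_ti_state_dict` is our illustrative helper name, not something from the repository:

```python
# Sketch of the SDXL Textual Inversion embedding layout: one entry per text
# encoder, each holding num_vectors rows of that encoder's embedding width.
def make_sdxl_ti_state_dict(num_vectors):
    embs_768 = [[0.0] * 768 for _ in range(num_vectors)]    # text encoder 1 ("clip_l")
    embs_1280 = [[0.0] * 1280 for _ in range(num_vectors)]  # text encoder 2 ("clip_g")
    return {"clip_g": embs_1280, "clip_l": embs_768}

sd = make_sdxl_ti_state_dict(4)
assert set(sd) == {"clip_g", "clip_l"}
assert len(sd["clip_l"]) == len(sd["clip_g"]) == 4
```

In a real file each value is a tensor of shape `(num_vectors, dim)` saved with `safetensors.torch.save_file`, as shown in the README snippet.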
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
 
 class WrapperTokenizer:
     # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
+    # make open clip tokenizer compatible with HuggingFace tokenizer
     def __init__(self):
         open_clip_tokenizer = open_clip.tokenizer._tokenizer
         self.model_max_length = 77
         self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
         self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
-        self.pad_token_id = 0  # 結果から推定している
+        self.pad_token_id = 0  # 結果から推定している assumption from result
 
     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.tokenize(*args, **kwds)
@@ -107,6 +108,42 @@ class WrapperTokenizer:
         input_ids = input_ids[: eos_index + 1]  # include eos
         return SimpleNamespace(**{"input_ids": input_ids})
 
+    # for Textual Inversion
+    # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
+
+    def encode(self, text, add_special_tokens=False):
+        assert not add_special_tokens
+        input_ids = open_clip.tokenizer._tokenizer.encode(text)
+        return input_ids
+
+    def add_tokens(self, new_tokens):
+        tokens_to_add = []
+        for token in new_tokens:
+            token = token.lower()
+            if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
+                tokens_to_add.append(token)
+
+        # open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
+        for token in tokens_to_add:
+            open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
+            open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
+            open_clip.tokenizer._tokenizer.vocab_size += 1
+
+            # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
+            # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
+            # set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
+            # this is very rough, so it will not work if the specification of open clip tokenizer changes
+            open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
+
+        return len(tokens_to_add)
+
+    def convert_tokens_to_ids(self, tokens):
+        input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
+        return input_ids
+
+    def __len__(self):
+        return open_clip.tokenizer._tokenizer.vocab_size
+
+
 def load_tokenizers(args: argparse.Namespace):
     print("prepare tokenizers")
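The `add_tokens` trick in the hunk above can be illustrated with toy dictionaries standing in for open_clip's BPE `encoder`, `decoder`, and `cache`. This is a simplified sketch of the mechanism, not the library's actual internals:

```python
# New words are appended to the encoder/decoder maps with the "</w>"
# end-of-word marker, and pre-seeded into the cache so the BPE merge step
# never runs for them.
encoder = {"a</w>": 0, "photo</w>": 1}
decoder = {v: k for k, v in encoder.items()}
cache = {}

def add_tokens(new_tokens):
    added = 0
    for token in new_tokens:
        token = token.lower()
        if token + "</w>" in encoder:
            continue  # already a known word, matching the "not in encoder" check above
        new_id = len(encoder)
        encoder[token + "</w>"] = new_id
        decoder[new_id] = token + "</w>"
        cache[token] = token + "</w>"  # skip BPE entirely for this word
        added += 1
    return added

assert add_tokens(["myhoge"]) == 1
assert encoder["myhoge</w>"] == 2
assert add_tokens(["myhoge"]) == 0  # adding the same word twice is a no-op
```

The fragility the commit's comments warn about follows directly from this: the sketch (like the real code) pokes private state, so any change to open_clip's tokenizer layout breaks it.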
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
         print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
 
     assert (
-        not args.weighted_captions
+        not hasattr(args, "weighted_captions") or not args.weighted_captions
     ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
@@ -320,7 +320,7 @@ class PipelineLike:
         self.scheduler = scheduler
         self.safety_checker = None
 
-        # Textual Inversion # not tested yet
+        # Textual Inversion
        self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
             self.token_replacements_list.append({})
@@ -341,6 +341,10 @@ class PipelineLike:
         token_replacements = self.token_replacements_list[tokenizer_index]
 
         def replace_tokens(tokens):
+            # print("replace_tokens", tokens, "=>", token_replacements)
+            if isinstance(tokens, torch.Tensor):
+                tokens = tokens.tolist()
+
             new_tokens = []
             for token in tokens:
                 if token in token_replacements:
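The `replace_tokens` closure above exists so that a multi-vector embedding's single placeholder id expands into the full list of ids at prompt-encoding time. A minimal sketch of that expansion, with made-up token ids:

```python
# One placeholder id maps to the full list of vector ids, mirroring
# add_token_replacement / replace_tokens in the pipeline above.
token_replacements = {100: [100, 101, 102]}  # e.g. "hoge" -> "hoge hogea hogeb"

def replace_tokens(tokens):
    new_tokens = []
    for token in tokens:
        if token in token_replacements:
            new_tokens.extend(token_replacements[token])
        else:
            new_tokens.append(token)
    return new_tokens

assert replace_tokens([1, 100, 2]) == [1, 100, 101, 102, 2]
```

Note the expanded sequence is longer than the input, which is why the real code does this before padding/truncating to the model's maximum length.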
@@ -1594,19 +1598,26 @@ def main(args):
 
         if "string_to_param" in data:
             data = data["string_to_param"]
-        embeds1 = data["clip_l"]
-        embeds2 = data["clip_g"]
+        embeds1 = data["clip_l"]  # text encoder 1
+        embeds2 = data["clip_g"]  # text encoder 2
 
         num_vectors_per_token = embeds1.size()[0]
         token_string = os.path.splitext(os.path.basename(embeds_file))[0]
-        token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
+
+        # remove non-alphabet characters to avoid splitting by tokenizer
+        # TODO make random alphabet string
+        token_string = "".join([c for c in token_string if c.isalpha()])
+
+        token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
 
         # add new word to tokenizer, count is num_vectors_per_token
         num_added_tokens1 = tokenizer1.add_tokens(token_strings)
-        num_added_tokens2 = tokenizer2.add_tokens(token_strings)  # not working now
-        assert (
-            num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
-        ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
+        num_added_tokens2 = tokenizer2.add_tokens(token_strings)
+        assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
+            f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
+            + f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
+        )
 
         token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
         token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
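The filename-to-token logic in the hunk above (strip non-alphabetic characters, then append `a`, `b`, ... suffixes instead of digits, which the open-clip tokenizer would split) can be sketched as a standalone helper. `make_token_strings` is our name for it, not the script's:

```python
# Derive the activation token strings for a multi-vector embedding from its
# filename stem, as the commit does for sdxl_gen_img.py.
def make_token_strings(token_string, num_vectors_per_token):
    # keep letters only, so the open-clip tokenizer treats it as one word
    token_string = "".join(c for c in token_string if c.isalpha())
    return [token_string] + [
        f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)
    ]

# matches the README example: neg_hand_v1.safetensors -> "neghandv"
assert make_token_strings("neg_hand_v1", 3) == ["neghandv", "neghandva", "neghandvb"]
```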
@@ -1617,11 +1628,11 @@ def main(args):
         assert (
             min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
         ), f"token ids2 is not ordered"
-        assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
-        assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
+        assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
+        assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
 
         if num_vectors_per_token > 1:
-            pipe.add_token_replacement(0, token_ids1[0], token_ids1)
+            pipe.add_token_replacement(0, token_ids1[0], token_ids1)  # hoge -> hoge, hogea, hogeb, ...
             pipe.add_token_replacement(1, token_ids2[0], token_ids2)
 
         token_ids_embeds1.append((token_ids1, embeds1))

sdxl_train_textual_inversion.py (new file, 142 lines)
@@ -0,0 +1,142 @@
+import argparse
+import os
+
+import regex
+import torch
+import open_clip
+from library import sdxl_model_util, sdxl_train_util, train_util
+
+import train_textual_inversion
+
+
+class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
+    def __init__(self):
+        super().__init__()
+        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
+
+    def assert_extra_args(self, args, train_dataset_group):
+        super().assert_extra_args(args, train_dataset_group)
+        sdxl_train_util.verify_sdxl_training_args(args)
+
+    def load_target_model(self, args, weight_dtype, accelerator):
+        (
+            load_stable_diffusion_format,
+            text_encoder1,
+            text_encoder2,
+            vae,
+            unet,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
+
+        self.load_stable_diffusion_format = load_stable_diffusion_format
+        self.logit_scale = logit_scale
+        self.ckpt_info = ckpt_info
+
+        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
+
+    def load_tokenizer(self, args):
+        tokenizer = sdxl_train_util.load_tokenizers(args)
+        return tokenizer
+
+    def assert_token_string(self, token_string, tokenizers):
+        # tokenizer 1 is seems to be ok
+
+        # count words for token string: regular expression from open_clip
+        pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
+        words = regex.findall(pat, token_string)
+        word_count = len(words)
+        assert word_count == 1, (
+            f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
+            + f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
+        )
+
+    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
+        input_ids1 = batch["input_ids"]
+        input_ids2 = batch["input_ids2"]
+        with torch.enable_grad():
+            input_ids1 = input_ids1.to(accelerator.device)
+            input_ids2 = input_ids2.to(accelerator.device)
+            encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
+                args,
+                input_ids1,
+                input_ids2,
+                tokenizers[0],
+                tokenizers[1],
+                text_encoders[0],
+                text_encoders[1],
+                None if not args.full_fp16 else weight_dtype,
+            )
+        return encoder_hidden_states1, encoder_hidden_states2, pool2
+
+    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+        noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
+
+        # get size embeddings
+        orig_size = batch["original_sizes_hw"]
+        crop_size = batch["crop_top_lefts"]
+        target_size = batch["target_sizes_hw"]
+        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
+
+        # concat embeddings
+        encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
+        vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
+        text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
+
+        noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
+        return noise_pred
+
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+        sdxl_train_util.sample_images(
+            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+        )
+
+    def save_weights(self, file, updated_embs, save_dtype):
+        state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
+
+        if save_dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(save_dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            save_file(state_dict, file)
+        else:
+            torch.save(state_dict, file)
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            data = load_file(file)
+        else:
+            data = torch.load(file, map_location="cpu")
+
+        emb_l = data.get("clib_l", None)  # ViT-L text encoder 1
+        emb_g = data.get("clib_g", None)  # BiG-G text encoder 2
+
+        assert (
+            emb_l is not None or emb_g is not None
+        ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
+
+        return [emb_l, emb_g]
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = train_textual_inversion.setup_parser()
+    # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
+    # sdxl_train_util.add_sdxl_training_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+
+    trainer = SdxlTextualInversionTrainer()
+    trainer.train(args)
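The new file relies on a template-method pattern: the base `TextualInversionTrainer` drives the training loop and this SDXL subclass overrides the model-specific hooks (`load_target_model`, `load_tokenizer`, `save_weights`, ...). A toy sketch of the pattern, with illustrative class names only:

```python
# The base class owns the loop; subclasses swap out the hooks. This mirrors
# how SdxlTextualInversionTrainer overrides TextualInversionTrainer above.
class BaseTrainer:
    def load_tokenizer(self):
        return "sd-tokenizer"  # single tokenizer for SD1/SD2

    def run(self):
        # the "loop" only reports which tokenizer(s) it would train with
        return f"training with {self.load_tokenizer()}"

class SdxlTrainer(BaseTrainer):
    def load_tokenizer(self):
        # SDXL has two text encoders, hence a list of tokenizers
        return ["clip-l-tokenizer", "clip-g-tokenizer"]

assert SdxlTrainer().run() == "training with ['clip-l-tokenizer', 'clip-g-tokenizer']"
```

This is why `train()` in the base class normalizes `load_tokenizer`'s return value into a list: the same loop body then works for one tokenizer or two.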
@@ -8,6 +8,7 @@ from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
+from library import model_util
 
 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -20,8 +21,6 @@ import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
     prepare_scheduler_for_custom_training,
-    pyramid_noise_like,
-    apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
 )
@@ -78,7 +77,81 @@ imagenet_style_templates_small = [
 ]
 
 
-def train(args):
+class TextualInversionTrainer:
+    def __init__(self):
+        self.vae_scale_factor = 0.18215
+
+    def assert_extra_args(self, args, train_dataset_group):
+        pass
+
+    def load_target_model(self, args, weight_dtype, accelerator):
+        text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
+        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
+
+    def load_tokenizer(self, args):
+        tokenizer = train_util.load_tokenizer(args)
+        return tokenizer
+
+    def assert_token_string(self, token_string, tokenizers):
+        pass
+
+    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
+        with torch.enable_grad():
+            input_ids = batch["input_ids"].to(accelerator.device)
+            encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
+        return encoder_hidden_states
+
+    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+        noise_pred = unet(noisy_latents, timesteps, text_conds).sample
+        return noise_pred
+
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+        train_util.sample_images(
+            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+        )
+
+    def save_weights(self, file, updated_embs, save_dtype):
+        state_dict = {"emb_params": updated_embs[0]}
+
+        if save_dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(save_dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            save_file(state_dict, file)
+        else:
+            torch.save(state_dict, file)  # can be loaded in Web UI
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            data = load_file(file)
+        else:
+            # compatible to Web UI's file format
+            data = torch.load(file, map_location="cpu")
+            if type(data) != dict:
+                raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
+
+            if "string_to_param" in data:  # textual inversion embeddings
+                data = data["string_to_param"]
+                if hasattr(data, "_parameters"):  # support old PyTorch?
+                    data = getattr(data, "_parameters")
+
+        emb = next(iter(data.values()))
+        if type(emb) != torch.Tensor:
+            raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
+
+        if len(emb.size()) == 1:
+            emb = emb.unsqueeze(0)
+
+        return [emb]
+
+    def train(self, args):
         if args.output_name is None:
             args.output_name = args.token_string
         use_template = args.use_object_template or args.use_style_template
@@ -91,7 +164,8 @@ def train(args):
         if args.seed is not None:
             set_seed(args.seed)
 
-        tokenizer = train_util.load_tokenizer(args)
+        tokenizer_or_list = self.load_tokenizer(args)  # list of tokenizer or tokenizer
+        tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
 
         # acceleratorを準備する
         print("prepare accelerator")
@@ -101,29 +175,60 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||||
|
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
|
||||||
|
|
||||||
|
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
|
||||||
|
accelerator.print(
|
||||||
|
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
|
||||||
|
+ "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
|
||||||
|
)
|
||||||
|
|
||||||
# Convert the init_word to token_id
|
# Convert the init_word to token_id
|
||||||
|
init_token_ids_list = []
|
||||||
if args.init_word is not None:
|
if args.init_word is not None:
|
||||||
|
for i, tokenizer in enumerate(tokenizers):
|
||||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||||
accelerator.print(
|
accelerator.print(
|
||||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / "
|
||||||
|
+ f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
|
||||||
)
|
)
|
||||||
|
init_token_ids_list.append(init_token_ids)
|
||||||
else:
|
else:
|
||||||
init_token_ids = None
|
init_token_ids_list = [None] * len(tokenizers)
|
||||||
|
|
||||||
|
-    # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
     # add new word to tokenizer, count is num_vectors_per_token
-    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
+    # token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
+    # 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした
+
+    # if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
+    # originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used
+
+    self.assert_token_string(args.token_string, tokenizers)
+
+    token_strings = [args.token_string] + [
+        f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
+    ]
+    token_ids_list = []
+    token_embeds_list = []
+    for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
         num_added_tokens = tokenizer.add_tokens(token_strings)
         assert (
             num_added_tokens == args.num_vectors_per_token
-        ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
+        ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}"

         token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-        accelerator.print(f"tokens are added: {token_ids}")
-        assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
-        assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
+        accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}")
+        assert (
+            min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
+        ), f"token ids is not ordered : tokenizer {i+1}, {token_ids}"
+        assert (
+            len(tokenizer) - 1 == token_ids[-1]
+        ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
+        token_ids_list.append(token_ids)

         # Resize the token embeddings as we are adding new special tokens to the tokenizer
         text_encoder.resize_token_embeddings(len(tokenizer))
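The letter-suffix naming the diff switches to (avoiding digits, since open clip's tokenizer splits words containing numbers) can be sketched in isolation. `make_token_strings` is a hypothetical helper, not a function in the script:

```python
def make_token_strings(token_string: str, num_vectors_per_token: int) -> list:
    # first vector keeps the base word; extra vectors get letter suffixes a, b, ...
    # (digits are avoided because open clip's tokenizer splits words containing numbers)
    return [token_string] + [
        f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)
    ]

print(make_token_strings("hoge", 3))  # → ['hoge', 'hogea', 'hogeb']
```

With `num_vectors_per_token=1` only the base word is returned, matching the single-vector case.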
@@ -134,14 +239,16 @@ def train(args):
         for i, token_id in enumerate(token_ids):
             token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
             # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+        token_embeds_list.append(token_embeds)

     # load weights
     if args.weights is not None:
-        embeddings = load_weights(args.weights)
+        embeddings_list = self.load_weights(args.weights)
         assert len(token_ids) == len(
-            embeddings
+            embeddings_list[0]
         ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
         # accelerator.print(token_ids, embeddings.size())
+        for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list):
             for token_id, embedding in zip(token_ids, embeddings):
                 token_embeds[token_id] = embedding
                 # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
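The init loop above copies existing embedding rows into the newly added slots, cycling through the init token ids when there are more new vectors than init tokens. A torch-free sketch of just the index cycling (`init_assignment` is a hypothetical helper; the real code assigns rows of `token_embeds` directly):

```python
def init_assignment(token_ids, init_token_ids):
    # map each newly added token id to the init token id whose embedding row it copies,
    # cycling when there are more new vectors than init tokens
    return {token_id: init_token_ids[i % len(init_token_ids)] for i, token_id in enumerate(token_ids)}

print(init_assignment([100, 101, 102], [5, 7]))  # → {100: 5, 101: 7, 102: 5}
```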
@@ -190,10 +297,12 @@ def train(args):
                 ]
             }

-        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
         train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     else:
-        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
+
+    self.assert_extra_args(args, train_dataset_group)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
@@ -210,11 +319,13 @@ def train(args):
             captions.append(tmpl.format(replace_to))
         train_dataset_group.add_replacement("", captions)

+        # サンプル生成用 (for sample generation)
         if args.num_vectors_per_token > 1:
             prompt_replacement = (args.token_string, replace_to)
         else:
             prompt_replacement = None
     else:
+        # サンプル生成用 (for sample generation)
         if args.num_vectors_per_token > 1:
             replace_to = " ".join(token_strings)
             train_dataset_group.add_replacement(args.token_string, replace_to)
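For sample generation the single trigger word has to expand into the full run of placeholder tokens; the replacement pair registered above behaves like the following sketch (`expand_prompt` is a hypothetical helper illustrating `add_replacement`, not the script's API):

```python
def expand_prompt(prompt: str, token_string: str, token_strings: list) -> str:
    # replace the bare trigger word with the space-joined placeholder tokens
    replace_to = " ".join(token_strings)
    return prompt.replace(token_string, replace_to)

print(expand_prompt("a photo of hoge", "hoge", ["hoge", "hogea", "hogeb"]))
# → 'a photo of hoge hogea hogeb'
```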
@@ -236,6 +347,7 @@ def train(args):

     # モデルに xformers とか memory efficient attention を組み込む (embed xformers / memory efficient attention into the model)
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+    vae.set_use_memory_efficient_attention_xformers(args.xformers)

     # 学習を準備する (prepare for training)
     if cache_latents:
@@ -253,11 +365,14 @@ def train(args):

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        for text_encoder in text_encoders:
             text_encoder.gradient_checkpointing_enable()

     # 学習に必要なクラスを準備する (prepare classes needed for training)
     accelerator.print("prepare optimizer, data loader etc.")
-    trainable_params = text_encoder.get_input_embeddings().parameters()
+    trainable_params = []
+    for text_encoder in text_encoders:
+        trainable_params += text_encoder.get_input_embeddings().parameters()
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)

     # dataloaderを準備する (prepare the dataloader)
@@ -277,7 +392,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(
+            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+        )

     # データセット側にも学習ステップを送信 (send the training step count to the dataset side as well)
     train_dataset_group.set_max_train_steps(args.max_train_steps)
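The override above converts an epoch count into optimizer steps: dataloader batches are divided across processes and gradient-accumulation steps, rounded up per epoch. A sketch with hypothetical numbers:

```python
import math

def steps_for_epochs(max_train_epochs, dataloader_len, num_processes, gradient_accumulation_steps):
    # optimizer steps per process per epoch, rounded up, times the epoch count
    return max_train_epochs * math.ceil(dataloader_len / num_processes / gradient_accumulation_steps)

# e.g. 100 batches, 2 GPUs, accumulation of 4: ceil(100 / 2 / 4) = 13 steps per epoch
print(steps_for_epochs(10, 100, 2, 4))  # → 130
```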
@@ -286,16 +403,34 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # acceleratorがなんかよろしくやってくれるらしい (accelerator apparently takes care of things for us)
-    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
-    )
+    if len(text_encoders) == 1:
+        text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
+        )

-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
+        # transform DDP after prepare
+        text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
+
+    elif len(text_encoders) == 2:
+        text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
+        )
+        # transform DDP after prepare
+        text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
+
+        text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
+
+    else:
+        raise NotImplementedError()
+
+    index_no_updates_list = []
+    orig_embeds_params_list = []
+    for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
         index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
+        index_no_updates_list.append(index_no_updates)

         # accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
         orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
+        orig_embeds_params_list.append(orig_embeds_params)

         # Freeze all parameters except for the token embeddings in text encoder
         text_encoder.requires_grad_(True)
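Because `tokenizer.add_tokens` appends ids at the end of the vocabulary, every id below the first new token id belongs to the original vocabulary, and the boolean mask built above marks exactly those rows as frozen. A list-based sketch (the real code uses `torch.arange`; `frozen_row_mask` is a hypothetical helper):

```python
def frozen_row_mask(vocab_size: int, first_new_token_id: int) -> list:
    # True for original-vocabulary rows (frozen), False for the newly appended rows
    return [i < first_new_token_id for i in range(vocab_size)]

# vocabulary of 6 ids where ids 4 and 5 are the newly added tokens
print(frozen_row_mask(6, 4))  # → [True, True, True, True, False, False]
```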
@@ -307,6 +442,7 @@ def train(args):
     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
+        # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す (the U-Net was replaced with the original one, so this should be unnecessary; check later and remove)
         unet.train()
     else:
         unet.eval()
@@ -319,6 +455,10 @@ def train(args):
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする (experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16)
     if args.full_fp16:
         train_util.patch_accelerator_for_fp16_training(accelerator)
+        for text_encoder in text_encoders:
+            text_encoder.to(weight_dtype)
+    if args.full_bf16:
+        for text_encoder in text_encoders:
             text_encoder.to(weight_dtype)

     # resumeする (resume)
@@ -356,12 +496,12 @@ def train(args):
     accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

     # function for saving/removing
-    def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
+    def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)

         accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
-        save_weights(ckpt_file, embs, save_dtype)
+        self.save_weights(ckpt_file, embs_list, save_dtype)
         if args.huggingface_repo_id is not None:
             huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -376,34 +516,36 @@ def train(args):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1

+        for text_encoder in text_encoders:
             text_encoder.train()

         loss_total = 0

         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
-            with accelerator.accumulate(text_encoder):
+            with accelerator.accumulate(text_encoders[0]):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
                     else:
                         # latentに変換 (convert to latents)
                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
-                        latents = latents * 0.18215
-                b_size = latents.shape[0]
+                        latents = latents * self.vae_scale_factor

                 # Get the text embedding for conditioning
-                input_ids = batch["input_ids"].to(accelerator.device)
-                # use float instead of fp16/bf16 because text encoder is float
-                encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
+                text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                 # Predict the noise residual
                 with accelerator.autocast():
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    noise_pred = self.call_unet(
+                        args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                    )

                 if args.v_parameterization:
                     # v-parameterization training
@@ -435,29 +577,48 @@ def train(args):

                 # Let's make sure we don't update any embedding weights besides the newly added token
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
-                        index_no_updates
-                    ]
+                    for text_encoder, orig_embeds_params, index_no_updates in zip(
+                        text_encoders, orig_embeds_params_list, index_no_updates_list
+                    ):
+                        accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                            index_no_updates
+                        ] = orig_embeds_params[index_no_updates]

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1

-                train_util.sample_images(
-                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
-                )
+                self.sample_images(
+                    accelerator,
+                    args,
+                    None,
+                    global_step,
+                    accelerator.device,
+                    vae,
+                    tokenizer_or_list,
+                    text_encoder_or_list,
+                    unet,
+                    prompt_replacement,
+                )

                 # 指定ステップごとにモデルを保存 (save the model every specified number of steps)
                 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
+                        updated_embs_list = []
+                        for text_encoder, token_ids in zip(text_encoders, token_ids_list):
                             updated_embs = (
-                                accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
+                                accelerator.unwrap_model(text_encoder)
+                                .get_input_embeddings()
+                                .weight[token_ids]
+                                .data.detach()
+                                .clone()
                             )
+                            updated_embs_list.append(updated_embs)

                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
-                        save_model(ckpt_name, updated_embs, global_step, epoch)
+                        save_model(ckpt_name, updated_embs_list, global_step, epoch)

                         if args.save_state:
                             train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
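An optimizer step may move every embedding row, so the loop above writes the saved original rows back wherever the mask is True, leaving only the new token rows effectively trainable. A torch-free sketch with plain lists (`restore_frozen_rows` is a hypothetical helper; the real code does the same assignment on the embedding weight tensor):

```python
def restore_frozen_rows(weight, orig_weight, index_no_updates):
    # overwrite frozen rows with their saved originals; keep updated rows as-is
    for i, frozen in enumerate(index_no_updates):
        if frozen:
            weight[i] = orig_weight[i]
    return weight

w = [[0.9], [0.8], [0.5], [0.4]]      # rows after an optimizer step
orig = [[1.0], [1.0], [1.0], [1.0]]   # snapshot taken before training
print(restore_frozen_rows(w, orig, [True, True, False, False]))
# → [[1.0], [1.0], [0.5], [0.4]]
```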
@@ -492,13 +653,16 @@ def train(args):

         accelerator.wait_for_everyone()

+        updated_embs_list = []
+        for text_encoder, token_ids in zip(text_encoders, token_ids_list):
             updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
+            updated_embs_list.append(updated_embs)

         if args.save_every_n_epochs is not None:
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
             if accelerator.is_main_process and saving:
                 ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                save_model(ckpt_name, updated_embs, epoch + 1, global_step)
+                save_model(ckpt_name, updated_embs_list, epoch + 1, global_step)

                 remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                 if remove_epoch_no is not None:
@@ -508,8 +672,17 @@ def train(args):
                 if args.save_state:
                     train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

-        train_util.sample_images(
-            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
-        )
+        self.sample_images(
+            accelerator,
+            args,
+            epoch + 1,
+            global_step,
+            accelerator.device,
+            vae,
+            tokenizer_or_list,
+            text_encoder_or_list,
+            unet,
+            prompt_replacement,
+        )

         # end of epoch
@@ -525,58 +698,13 @@ def train(args):

     updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()

-    del accelerator  # この後メモリを使うのでこれは消す (delete this because memory is needed afterwards)
-
     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
-        save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
+        save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)

         print("model saved.")


-def save_weights(file, updated_embs, save_dtype):
-    state_dict = {"emb_params": updated_embs}
-
-    if save_dtype is not None:
-        for key in list(state_dict.keys()):
-            v = state_dict[key]
-            v = v.detach().clone().to("cpu").to(save_dtype)
-            state_dict[key] = v
-
-    if os.path.splitext(file)[1] == ".safetensors":
-        from safetensors.torch import save_file
-
-        save_file(state_dict, file)
-    else:
-        torch.save(state_dict, file)  # can be loaded in Web UI
-
-
-def load_weights(file):
-    if os.path.splitext(file)[1] == ".safetensors":
-        from safetensors.torch import load_file
-
-        data = load_file(file)
-    else:
-        # compatible to Web UI's file format
-        data = torch.load(file, map_location="cpu")
-        if type(data) != dict:
-            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
-
-        if "string_to_param" in data:  # textual inversion embeddings
-            data = data["string_to_param"]
-            if hasattr(data, "_parameters"):  # support old PyTorch?
-                data = getattr(data, "_parameters")
-
-    emb = next(iter(data.values()))
-    if type(emb) != torch.Tensor:
-        raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
-
-    if len(emb.size()) == 1:
-        emb = emb.unsqueeze(0)
-
-    return emb
-
-
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
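The module-level `load_weights` removed above (it becomes a trainer method) accepts both this script's `{"emb_params": tensor}` format and the Web UI's `{"string_to_param": {...}}` format. The dict normalization can be sketched without torch, treating embeddings as nested lists and mimicking `unsqueeze(0)` for a single vector (`normalize_embedding_dict` is a hypothetical stand-in, not the method itself):

```python
def normalize_embedding_dict(data: dict):
    # Web UI textual inversion files nest the tensor under "string_to_param"
    if "string_to_param" in data:
        data = data["string_to_param"]
    emb = next(iter(data.values()))
    # a single vector is promoted to a 1-row matrix (like tensor.unsqueeze(0))
    if emb and not isinstance(emb[0], list):
        emb = [emb]
    return emb

print(normalize_embedding_dict({"emb_params": [[0.1, 0.2]]}))            # → [[0.1, 0.2]]
print(normalize_embedding_dict({"string_to_param": {"*": [0.1, 0.2]}}))  # → [[0.1, 0.2]]
```

Either way the result is a 2-D array of shape (num_vectors, embedding_dim), which is what the weight-loading path above indexes per token id.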
@@ -626,4 +754,5 @@ if __name__ == "__main__":
     args = parser.parse_args()
     args = train_util.read_config_from_file(args, parser)

-    train(args)
+    trainer = TextualInversionTrainer()
+    trainer.train(args)