mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support textual inversion training
This commit is contained in:
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
|
||||
class WrapperTokenizer:
|
||||
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
|
||||
# make open clip tokenizer compatible with HuggingFace tokenizer
|
||||
def __init__(self):
    """Populate the HuggingFace-style attributes exposed by this wrapper."""
    special_ids = open_clip.tokenizer._tokenizer.all_special_ids

    self.model_max_length = 77
    # The first two entries of open clip's special id list are used as BOS and EOS.
    self.bos_token_id = special_ids[0]
    self.eos_token_id = special_ids[1]
    # The pad id is an assumption inferred from observed results, not read from open clip.
    self.pad_token_id = 0
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
    """Forward all positional and keyword arguments to ``self.tokenize`` and return its result."""
    tokenized = self.tokenize(*args, **kwds)
    return tokenized
|
||||
@@ -107,6 +108,42 @@ class WrapperTokenizer:
|
||||
input_ids = input_ids[: eos_index + 1] # include eos
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
# for Textual Inversion
|
||||
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
|
||||
|
||||
def encode(self, text, add_special_tokens=False):
    """Return what the open clip tokenizer's ``encode`` produces for ``text``.

    ``add_special_tokens`` must be falsy; any truthy value fails the assertion.
    """
    assert not add_special_tokens
    return open_clip.tokenizer._tokenizer.encode(text)
|
||||
|
||||
def add_tokens(self, new_tokens):
    """Register new tokens directly in the module-level open clip tokenizer.

    Each token is lower-cased and stored under the key ``token + "</w>"``.
    Tokens whose key already exists in the encoder are skipped, which also
    means a token repeated inside ``new_tokens`` is only registered once.

    Args:
        new_tokens: an iterable of token strings, or a single token string.

    Returns:
        The number of tokens that were actually added.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(new_tokens, str):
        new_tokens = [new_tokens]

    clip_tok = open_clip.tokenizer._tokenizer
    added = 0
    for token in new_tokens:
        token = token.lower()
        word = token + "</w>"
        # Checking right before inserting makes duplicates within this call hit the
        # "already present" path instead of being assigned a second id.
        if word in clip_tok.encoder:
            continue

        # Add the token to the open clip tokenizer tables directly, using one id for
        # both directions so encoder and decoder stay consistent.
        new_id = len(clip_tok.encoder)
        clip_tok.encoder[word] = new_id
        clip_tok.decoder[new_id] = word
        clip_tok.vocab_size += 1

        # HACK: write the tokenizer's cache directly so the token can be tokenized
        # even though it is not included in the BPE data. This is very rough and
        # will stop working if the open clip tokenizer's internals change.
        clip_tok.cache[token] = word
        added += 1

    return added
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
    """Map each token to its id in the open clip tokenizer's encoder table.

    Every token is looked up with the ``"</w>"`` suffix appended; a token that
    is missing from the table raises ``KeyError``.
    """
    ids = []
    for tok in tokens:
        ids.append(open_clip.tokenizer._tokenizer.encoder[tok + "</w>"])
    return ids
|
||||
|
||||
def __len__(self):
    """Report the open clip tokenizer's current ``vocab_size``."""
    size = open_clip.tokenizer._tokenizer.vocab_size
    return size
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user