support textual inversion training

Kohya S
2023-07-10 22:04:02 +09:00
parent b6e328ea8f
commit f54b784d88
5 changed files with 787 additions and 446 deletions


@@ -320,7 +320,7 @@ class PipelineLike:
         self.scheduler = scheduler
         self.safety_checker = None
-        # Textual Inversion # not tested yet
+        # Textual Inversion
         self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
             self.token_replacements_list.append({})
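
For context, SDXL has two text encoders, so the pipeline keeps one replacement dict per encoder. A minimal illustrative sketch of how such a mapping is registered and what it holds (the names add_token_replacement and token_replacements_list come from the hunk; the concrete token ids are hypothetical):

# one {trigger_token_id: [token_id, ...]} dict per text encoder (CLIP-L and CLIP-G for SDXL)
token_replacements_list = [{}, {}]

def add_token_replacement(text_encoder_index, target_token_id, rep_token_ids):
    # the single trigger token will later be expanded into all embedding token ids
    token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids

add_token_replacement(0, 49408, [49408, 49409, 49410])  # hypothetical ids for a 3-vector embedding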
@@ -341,6 +341,10 @@ class PipelineLike:
         token_replacements = self.token_replacements_list[tokenizer_index]
         def replace_tokens(tokens):
             # print("replace_tokens", tokens, "=>", token_replacements)
+            if isinstance(tokens, torch.Tensor):
+                tokens = tokens.tolist()
             new_tokens = []
             for token in tokens:
                 if token in token_replacements:
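
The hunk above is cut off mid-loop; a minimal sketch of the complete replacement logic, assuming it matches the surrounding get_token_replacer helper, is shown below. Token ids that appear as keys in the replacement dict are expanded into their full multi-vector id lists; everything else passes through unchanged.

import torch

def make_replace_tokens(token_replacements):
    # token_replacements: {trigger_token_id: [token_id, token_id_a, ...]}
    def replace_tokens(tokens):
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        new_tokens = []
        for token in tokens:
            if token in token_replacements:
                new_tokens.extend(token_replacements[token])  # expand trigger into all vectors
            else:
                new_tokens.append(token)
        return new_tokens

    return replace_tokens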
@@ -1594,19 +1598,26 @@ def main(args):
if "string_to_param" in data:
data = data["string_to_param"]
embeds1 = data["clip_l"]
embeds2 = data["clip_g"]
embeds1 = data["clip_l"] # text encoder 1
embeds2 = data["clip_g"] # text encoder 2
num_vectors_per_token = embeds1.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# remove non-alphabet characters to avoid splitting by tokenizer
# TODO make random alphabet string
token_string = "".join([c for c in token_string if c.isalpha()])
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now
assert (
num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
)
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
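
The embedding file is expected to contain one tensor per SDXL text encoder under the keys clip_l and clip_g (possibly nested under string_to_param). A sketch of loading such a file and deriving the token strings, following the same naming rule as the hunk; the safetensors-vs-torch detection here is an assumption, not the commit's exact code:

import os
import torch
from safetensors.torch import load_file

def load_sdxl_ti_embedding(embeds_file):
    # load either a safetensors file or a torch checkpoint
    if os.path.splitext(embeds_file)[1] == ".safetensors":
        data = load_file(embeds_file)
    else:
        data = torch.load(embeds_file, map_location="cpu")
    if "string_to_param" in data:
        data = data["string_to_param"]

    embeds1 = data["clip_l"]  # (num_vectors, dim) for text encoder 1
    embeds2 = data["clip_g"]  # (num_vectors, dim) for text encoder 2

    # trigger word = file name, alphabet characters only (same rule as the hunk)
    token_string = os.path.splitext(os.path.basename(embeds_file))[0]
    token_string = "".join(c for c in token_string if c.isalpha())
    token_strings = [token_string] + [
        f"{token_string}{chr(ord('a') + i)}" for i in range(embeds1.size(0) - 1)
    ]
    return token_strings, embeds1, embeds2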
@@ -1617,11 +1628,11 @@ def main(args):
             assert (
                 min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
             ), f"token ids2 is not ordered"
-            assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
-            assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
+            assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
+            assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
             if num_vectors_per_token > 1:
-                pipe.add_token_replacement(0, token_ids1[0], token_ids1)
+                pipe.add_token_replacement(0, token_ids1[0], token_ids1)  # hoge -> hoge, hogea, hogeb, ...
                 pipe.add_token_replacement(1, token_ids2[0], token_ids2)
             token_ids_embeds1.append((token_ids1, embeds1))
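
The collected (token_ids, embeds) pairs still have to be written into the text encoders' embedding tables; that part is outside this hunk. The usual pattern (a sketch, not necessarily the commit's code) is to resize the embedding table for the added tokens and copy each vector into its new slot:

def apply_token_ids_embeds(tokenizer, text_encoder, token_ids_embeds):
    # make room for the tokens added via tokenizer.add_tokens()
    text_encoder.resize_token_embeddings(len(tokenizer))
    emb = text_encoder.get_input_embeddings().weight.data
    for token_ids, embeds in token_ids_embeds:
        for token_id, embed in zip(token_ids, embeds):
            emb[token_id] = embed  # copy one vector per added token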