support textual inversion training
@@ -320,7 +320,7 @@ class PipelineLike:
         self.scheduler = scheduler
         self.safety_checker = None

-        # Textual Inversion # not tested yet
+        # Textual Inversion
         self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
             self.token_replacements_list.append({})
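The constructor now allocates one replacement dict per text encoder (two for SDXL: CLIP-L and CLIP-G). The registration helper itself is not shown in this diff; a minimal sketch, assuming only the signature visible in the pipe.add_token_replacement(text_encoder_index, target_token_id, rep_token_ids) calls further down:

def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
    # map the placeholder's single token id to the full list of vector ids,
    # in the dict belonging to the chosen text encoder (assumed body)
    self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids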
@@ -341,6 +341,10 @@ class PipelineLike:
         token_replacements = self.token_replacements_list[tokenizer_index]

         def replace_tokens(tokens):
+            # print("replace_tokens", tokens, "=>", token_replacements)
+            if isinstance(tokens, torch.Tensor):
+                tokens = tokens.tolist()
+
             new_tokens = []
             for token in tokens:
                 if token in token_replacements:
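The hunk ends inside the loop body. A sketch of how the rest of replace_tokens plausibly behaves, assuming each matched id is expanded in place (only the lines above are from the commit; the extend/append split is an assumption):

def replace_tokens(tokens):
    # the new guard: token ids may arrive as a tensor, so normalize to a list
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()

    new_tokens = []
    for token in tokens:
        if token in token_replacements:
            # expand the placeholder id into its registered multi-vector ids
            new_tokens.extend(token_replacements[token])
        else:
            new_tokens.append(token)
    return new_tokens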
@@ -1594,19 +1598,26 @@ def main(args):

             if "string_to_param" in data:
                 data = data["string_to_param"]
-            embeds1 = data["clip_l"]
-            embeds2 = data["clip_g"]
+
+            embeds1 = data["clip_l"]  # text encoder 1
+            embeds2 = data["clip_g"]  # text encoder 2

             num_vectors_per_token = embeds1.size()[0]
             token_string = os.path.splitext(os.path.basename(embeds_file))[0]
-            token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
+
+            # remove non-alphabet characters to avoid splitting by tokenizer
+            # TODO make random alphabet string
+            token_string = "".join([c for c in token_string if c.isalpha()])
+
+            token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]

             # add new word to tokenizer, count is num_vectors_per_token
             num_added_tokens1 = tokenizer1.add_tokens(token_strings)
-            num_added_tokens2 = tokenizer2.add_tokens(token_strings)  # not working now
-            assert (
-                num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
-            ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
+            num_added_tokens2 = tokenizer2.add_tokens(token_strings)
+            assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
+                f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
+                + f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
+            )

             token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
             token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
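To see what the new naming scheme produces, a worked illustration with a hypothetical file name (the file name and vector count are made up; the two transformation lines are copied from the commit):

import os

embeds_file = "my-style_v2.safetensors"  # hypothetical embedding file
num_vectors_per_token = 3                # e.g. embeds1.size()[0]

token_string = os.path.splitext(os.path.basename(embeds_file))[0]  # "my-style_v2"
# non-alphabet characters would make the tokenizer split the word, so drop them
token_string = "".join([c for c in token_string if c.isalpha()])   # "mystylev"

token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
print(token_strings)  # ['mystylev', 'mystyleva', 'mystylevb']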
@@ -1617,11 +1628,11 @@ def main(args):
             assert (
                 min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
             ), f"token ids2 is not ordered"
-            assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
-            assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
+            assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
+            assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"

             if num_vectors_per_token > 1:
-                pipe.add_token_replacement(0, token_ids1[0], token_ids1)
+                pipe.add_token_replacement(0, token_ids1[0], token_ids1)  # hoge -> hoge, hogea, hogeb, ...
                 pipe.add_token_replacement(1, token_ids2[0], token_ids2)

             token_ids_embeds1.append((token_ids1, embeds1))
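End to end, the replacement means a multi-vector embedding is usable with a single word in the prompt. A hedged illustration with made-up token ids, relying on add_token_replacement and replace_tokens as sketched above:

# suppose tokenizer1.add_tokens(["mystylev", "mystyleva", "mystylevb"]) assigned:
token_ids1 = [49408, 49409, 49410]
pipe.add_token_replacement(0, token_ids1[0], token_ids1)

# the prompt "a portrait in mystylev style" tokenizes to [..., 49408, ...];
# replace_tokens turns that into [..., 49408, 49409, 49410, ...], so all three
# learned vectors reach text encoder 1 (encoder 2 is handled the same way via index 1)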