From 188e54b7605df4e488a21aadff73aa6cfd6169df Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 11 Feb 2023 15:00:11 +0900
Subject: [PATCH] support multiple init words

---
 train_textual_inversion.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index ba2e7145..4aa91eee 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -98,12 +98,12 @@ def train(args):
 
   # Convert the init_word to token_id
   if args.init_word is not None:
-    init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
-    assert len(
-        init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
-    init_token_id = init_token_id[0]
+    init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
+    if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
+      print(
+          f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
   else:
-    init_token_id = None
+    init_token_ids = None
 
   # add new word to tokenizer, count is num_vectors_per_token
   token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
@@ -120,9 +120,9 @@
 
   # Initialise the newly added placeholder token with the embeddings of the initializer token
   token_embeds = text_encoder.get_input_embeddings().weight.data
-  if init_token_id is not None:
-    for token_id in token_ids:
-      token_embeds[token_id] = token_embeds[init_token_id]
+  if init_token_ids is not None:
+    for i, token_id in enumerate(token_ids):
+      token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
       # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
 
   # load weights
@@ -492,7 +492,7 @@ if __name__ == '__main__':
   parser.add_argument("--token_string", type=str, default=None,
                       help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
  parser.add_argument("--init_word", type=str, default=None,
-                      help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
+                      help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
   parser.add_argument("--use_object_template", action='store_true',
                       help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
   parser.add_argument("--use_style_template", action='store_true',
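
For reference, a minimal standalone sketch of the new repeat-or-truncate mapping, assuming the Stable Diffusion v1 CLIP tokenizer from HuggingFace transformers; the init_word and num_vectors_per_token values below are illustrative stand-ins for the script's command-line arguments:

# Sketch of the patched behaviour: each placeholder vector i is initialized
# from init token i % len(init_token_ids), so a short init word list repeats
# and a long one is truncated. Example values only.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

init_word = "black cat"      # --init_word may now encode to several tokens
num_vectors_per_token = 3    # --num_vectors_per_token

init_token_ids = tokenizer.encode(init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != num_vectors_per_token:
    print(f"init words will be repeated or truncated: length {len(init_token_ids)}")

for i in range(num_vectors_per_token):
    # the training script copies token_embeds[src] into the i-th placeholder token
    src = init_token_ids[i % len(init_token_ids)]
    print(f"placeholder vector {i} <- init token id {src}")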