multi embed training

This commit is contained in:
Kohya S
2023-10-10 23:36:05 +09:00
parent 33ee0acd35
commit f8629e3c1a
2 changed files with 112 additions and 3 deletions

View File

@@ -7,10 +7,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -167,6 +170,13 @@ class TextualInversionTrainer:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
assert (
args.token_string is not None or args.token_strings is not None
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
assert (
not use_template or args.token_strings is None
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -215,9 +225,17 @@ class TextualInversionTrainer:
# add new word to tokenizer, count is num_vectors_per_token
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
self.assert_token_string(args.token_string, tokenizers)
if args.token_strings is not None:
token_strings = args.token_strings
assert (
len(token_strings) == args.num_vectors_per_token
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
for token_string in token_strings:
self.assert_token_string(token_string, tokenizers)
else:
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
@@ -332,7 +350,7 @@ class TextualInversionTrainer:
prompt_replacement = None
else:
# サンプル生成用
if args.num_vectors_per_token > 1:
if args.num_vectors_per_token > 1 and args.token_strings is None:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument(
"--token_strings",
type=str,
default=None,
nargs="*",
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",