mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
multi embed training
This commit is contained in:
@@ -7,10 +7,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -167,6 +170,13 @@ class TextualInversionTrainer:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
|
||||
assert (
|
||||
args.token_string is not None or args.token_strings is not None
|
||||
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
|
||||
assert (
|
||||
not use_template or args.token_strings is None
|
||||
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
@@ -215,9 +225,17 @@ class TextualInversionTrainer:
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
|
||||
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
if args.token_strings is not None:
|
||||
token_strings = args.token_strings
|
||||
assert (
|
||||
len(token_strings) == args.num_vectors_per_token
|
||||
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
|
||||
for token_string in token_strings:
|
||||
self.assert_token_string(token_string, tokenizers)
|
||||
else:
|
||||
self.assert_token_string(args.token_string, tokenizers)
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
token_ids_list = []
|
||||
token_embeds_list = []
|
||||
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
||||
@@ -332,7 +350,7 @@ class TextualInversionTrainer:
|
||||
prompt_replacement = None
|
||||
else:
|
||||
# サンプル生成用
|
||||
if args.num_vectors_per_token > 1:
|
||||
if args.num_vectors_per_token > 1 and args.token_strings is None:
|
||||
replace_to = " ".join(token_strings)
|
||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_strings",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
|
||||
)
|
||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument(
|
||||
"--use_object_template",
|
||||
|
||||
Reference in New Issue
Block a user