mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Support Textual Inversion
This commit is contained in:
@@ -470,6 +470,9 @@ class PipelineLike():
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
# Textual Inversion
|
||||
self.token_replacements = {}
|
||||
|
||||
# CLIP guidance
|
||||
self.clip_guidance_scale = clip_guidance_scale
|
||||
self.clip_image_guidance_scale = clip_image_guidance_scale
|
||||
@@ -484,6 +487,19 @@ class PipelineLike():
|
||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||
self.token_replacements[target_token_id] = rep_token_ids
|
||||
|
||||
def replace_token(self, tokens):
|
||||
new_tokens = []
|
||||
for token in tokens:
|
||||
if token in self.token_replacements:
|
||||
new_tokens.extend(self.token_replacements[token])
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
||||
for word, weight in texts_and_weights:
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||
|
||||
token = pipe.replace_token(token)
|
||||
|
||||
text_token += token
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
@@ -2039,6 +2058,44 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds = []
|
||||
for embeds_file in args.textual_inversion_embeddings:
|
||||
if model_util.is_safetensors(embeds_file):
|
||||
from safetensors.torch import load_file
|
||||
data = load_file(embeds_file)
|
||||
else:
|
||||
data = torch.load(embeds_file, map_location="cpu")
|
||||
|
||||
embeds = next(iter(data.values()))
|
||||
if type(embeds) != torch.Tensor:
|
||||
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
|
||||
|
||||
num_vectors_per_token = embeds.size()[0]
|
||||
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||
|
||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
|
||||
if num_vectors_per_token > 1:
|
||||
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||
|
||||
token_ids_embeds.append((token_ids, embeds))
|
||||
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
for token_ids, embeds in token_ids_embeds:
|
||||
for token_id, embed in zip(token_ids, embeds):
|
||||
token_embeds[token_id] = embed
|
||||
|
||||
# promptを取得する
|
||||
if args.from_file is not None:
|
||||
print(f"reading prompts from {args.from_file}")
|
||||
@@ -2157,8 +2214,8 @@ def main(args):
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||
|
||||
for iter in range(args.n_iter):
|
||||
print(f"iteration {iter+1}/{args.n_iter}")
|
||||
for gen_iter in range(args.n_iter):
|
||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||
iter_seed = random.randint(0, 0x7fffffff)
|
||||
|
||||
# バッチ処理の関数
|
||||
@@ -2527,6 +2584,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||
|
||||
Reference in New Issue
Block a user