mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Support Textual Inversion
This commit is contained in:
@@ -470,6 +470,9 @@ class PipelineLike():
|
|||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
|
# Textual Inversion
|
||||||
|
self.token_replacements = {}
|
||||||
|
|
||||||
# CLIP guidance
|
# CLIP guidance
|
||||||
self.clip_guidance_scale = clip_guidance_scale
|
self.clip_guidance_scale = clip_guidance_scale
|
||||||
self.clip_image_guidance_scale = clip_image_guidance_scale
|
self.clip_image_guidance_scale = clip_image_guidance_scale
|
||||||
@@ -484,6 +487,19 @@ class PipelineLike():
|
|||||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||||
|
|
||||||
|
# Textual Inversion
|
||||||
|
def add_token_replacement(self, target_token_id, rep_token_ids):
    """Register the token ids to substitute for a single token id.

    Args:
        target_token_id: Token id to look for in a tokenized prompt.
        rep_token_ids: Iterable of token ids that ``replace_token`` emits
            in place of ``target_token_id``.

    Registering the same ``target_token_id`` again overwrites the earlier
    entry.
    """
    # Keep a private copy so that later mutation of the caller's list
    # cannot silently change the registered replacement.
    self.token_replacements[target_token_id] = list(rep_token_ids)
|
||||||
|
|
||||||
|
def replace_token(self, tokens):
    """Expand registered token ids in ``tokens`` into their replacements.

    Each token id that has an entry in ``self.token_replacements`` is
    replaced by the ids stored for it; every other id is passed through
    unchanged. Expansion is a single pass: ids produced by a replacement
    are not looked up again, so a replacement may contain its own target
    id.

    Args:
        tokens: Iterable of token ids.

    Returns:
        A new list of token ids; the input is not modified.
    """
    replacements = self.token_replacements
    expanded = []
    for token in tokens:
        # Unregistered ids fall back to a one-element tuple so both cases
        # share the same extend() path.
        expanded.extend(replacements.get(token, (token,)))
    return expanded
|
||||||
|
|
||||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||||
def enable_xformers_memory_efficient_attention(self):
|
def enable_xformers_memory_efficient_attention(self):
|
||||||
r"""
|
r"""
|
||||||
@@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
for word, weight in texts_and_weights:
|
for word, weight in texts_and_weights:
|
||||||
# tokenize and discard the starting and the ending token
|
# tokenize and discard the starting and the ending token
|
||||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||||
|
|
||||||
|
token = pipe.replace_token(token)
|
||||||
|
|
||||||
text_token += token
|
text_token += token
|
||||||
# copy the weight by length of token
|
# copy the weight by length of token
|
||||||
text_weight += [weight] * len(token)
|
text_weight += [weight] * len(token)
|
||||||
@@ -2039,6 +2058,44 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Textual Inversionを処理する
|
||||||
|
if args.textual_inversion_embeddings:
|
||||||
|
token_ids_embeds = []
|
||||||
|
for embeds_file in args.textual_inversion_embeddings:
|
||||||
|
if model_util.is_safetensors(embeds_file):
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
data = load_file(embeds_file)
|
||||||
|
else:
|
||||||
|
data = torch.load(embeds_file, map_location="cpu")
|
||||||
|
|
||||||
|
embeds = next(iter(data.values()))
|
||||||
|
if type(embeds) != torch.Tensor:
|
||||||
|
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
|
||||||
|
|
||||||
|
num_vectors_per_token = embeds.size()[0]
|
||||||
|
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||||
|
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||||
|
|
||||||
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||||
|
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||||
|
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
|
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||||
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
|
if num_vectors_per_token > 1:
|
||||||
|
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||||
|
|
||||||
|
token_ids_embeds.append((token_ids, embeds))
|
||||||
|
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
for token_ids, embeds in token_ids_embeds:
|
||||||
|
for token_id, embed in zip(token_ids, embeds):
|
||||||
|
token_embeds[token_id] = embed
|
||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
print(f"reading prompts from {args.from_file}")
|
||||||
@@ -2157,8 +2214,8 @@ def main(args):
|
|||||||
os.makedirs(args.outdir, exist_ok=True)
|
os.makedirs(args.outdir, exist_ok=True)
|
||||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||||
|
|
||||||
for iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {iter+1}/{args.n_iter}")
|
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
iter_seed = random.randint(0, 0x7fffffff)
|
iter_seed = random.randint(0, 0x7fffffff)
|
||||||
|
|
||||||
# バッチ処理の関数
|
# バッチ処理の関数
|
||||||
@@ -2527,6 +2584,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
|
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||||
|
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||||
|
|||||||
Reference in New Issue
Block a user