diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 19c63acf..8b29e78e 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -470,6 +470,9 @@ class PipelineLike(): self.scheduler = scheduler self.safety_checker = None + # Textual Inversion + self.token_replacements = {} + # CLIP guidance self.clip_guidance_scale = clip_guidance_scale self.clip_image_guidance_scale = clip_image_guidance_scale @@ -484,6 +487,19 @@ class PipelineLike(): self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + # Textual Inversion + def add_token_replacement(self, target_token_id, rep_token_ids): + self.token_replacements[target_token_id] = rep_token_ids + + def replace_token(self, tokens): + new_tokens = [] + for token in tokens: + if token in self.token_replacements: + new_tokens.extend(self.token_replacements[token]) + else: + new_tokens.append(token) + return new_tokens + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): r""" @@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: for word, weight in texts_and_weights: # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1] + + token = pipe.replace_token(token) + text_token += token # copy the weight by length of token text_weight += [weight] * len(token) @@ -2039,6 +2058,44 @@ def main(args): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + embeds = next(iter(data.values())) + if type(embeds) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") + + num_vectors_per_token = embeds.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_ids_embeds.append((token_ids, embeds)) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + # promptを取得する if args.from_file is not None: print(f"reading prompts from {args.from_file}") @@ -2157,8 +2214,8 @@ def main(args): os.makedirs(args.outdir, exist_ok=True) max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - for iter in range(args.n_iter): - print(f"iteration {iter+1}/{args.n_iter}") + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7fffffff) # バッチ処理の関数 @@ -2527,6 +2584,8 @@ if __name__ == '__main__': parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_args", type=str, default=None, nargs='*', help='additional argmuments for network (key=value) / ネットワークへの追加の引数') + parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*', + help='Embeddings files of Textual Inversion / Textual Inversionのembeddings') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--max_embeddings_multiples", type=int, default=None, help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')