fix XTI embeddings not working

This commit is contained in:
Kohya S
2023-03-30 22:28:55 +09:00
parent 5fc80b7a5b
commit e76ea7cd7d

View File

@@ -536,6 +536,9 @@ class PipelineLike:
new_tokens.append(token) new_tokens.append(token)
return new_tokens return new_tokens
def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
self.token_replacements_XTI[target_token_id] = rep_token_ids
def set_control_nets(self, ctrl_nets): def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets self.control_nets = ctrl_nets
@@ -779,7 +782,24 @@ class PipelineLike:
if self.token_replacements_XTI: if self.token_replacements_XTI:
text_embeddings_concat = [] text_embeddings_concat = []
for layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']: for layer in [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self, pipe=self,
prompt=prompt, prompt=prompt,
@@ -801,14 +821,6 @@ class PipelineLike:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else: else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
# CLIP guidanceで使用するembeddingsを取得する # CLIP guidanceで使用するembeddingsを取得する
if self.clip_guidance_scale > 0: if self.clip_guidance_scale > 0:
@@ -1716,7 +1728,7 @@ def parse_prompt_attention(text):
return res return res
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
r""" r"""
Tokenize a list of prompts and return its tokens with weights of each token. Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included. No padding, starting or ending token is included.
@@ -1732,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
# tokenize and discard the starting and the ending token # tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1] token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token) token = pipe.replace_token(token, layer=layer)
text_token += token text_token += token
# copy the weight by length of token # copy the weight by length of token
@@ -1879,11 +1891,11 @@ def get_weighted_text_embeddings(
prompt = [prompt] prompt = [prompt]
if not skip_parsing: if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None: if uncond_prompt is not None:
if isinstance(uncond_prompt, str): if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt] uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
else: else:
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens] prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
@@ -2335,6 +2347,7 @@ def main(args):
if args.diffusers_xformers: if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
@@ -2386,25 +2399,45 @@ def main(args):
token_embeds[token_id] = embed token_embeds[token_id] = embed
if args.XTI_embeddings: if args.XTI_embeddings:
XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'] XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
token_ids_embeds_XTI = [] token_ids_embeds_XTI = []
for embeds_file in args.XTI_embeddings: for embeds_file in args.XTI_embeddings:
if model_util.is_safetensors(embeds_file): if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file from safetensors.torch import load_file
data = load_file(embeds_file) data = load_file(embeds_file)
else: else:
data = torch.load(embeds_file, map_location="cpu") data = torch.load(embeds_file, map_location="cpu")
if set(data.keys()) != set(XTI_layers): if set(data.keys()) != set(XTI_layers):
raise ValueError("NOT XTI") raise ValueError("NOT XTI")
embeds = torch.concat(list(data.values())) embeds = torch.concat(list(data.values()))
num_vectors_per_token = data['MID'].size()[0] num_vectors_per_token = data["MID"].size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token # add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings) num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" assert (
num_added_tokens == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings) token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
@@ -3090,8 +3123,8 @@ def setup_parser() -> argparse.ArgumentParser:
"--XTI_embeddings", "--XTI_embeddings",
type=str, type=str,
default=None, default=None,
nargs='*', nargs="*",
help='Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings' help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
) )
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
parser.add_argument( parser.add_argument(