Add workaround for clip's bug for pooled output

This commit is contained in:
Kohya S
2023-08-04 08:38:27 +09:00
parent cf6832896f
commit c6d52fdea4
4 changed files with 61 additions and 11 deletions

View File

@@ -18,7 +18,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput,
from diffusers.utils import logging from diffusers.utils import logging
from PIL import Image from PIL import Image
from library import sdxl_model_util, sdxl_train_util from library import sdxl_model_util, sdxl_train_util, train_util
try: try:
@@ -210,7 +210,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
return tokens, weights return tokens, weights
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, device): def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
if not is_sdxl_text_encoder2: if not is_sdxl_text_encoder2:
# text_encoder1: same as SD1/2 # text_encoder1: same as SD1/2
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
@@ -220,7 +220,8 @@ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, devi
# text_encoder2 # text_encoder2
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
hidden_states = enc_out["hidden_states"][-2] # penultimate layer hidden_states = enc_out["hidden_states"][-2] # penultimate layer
pool = enc_out["text_embeds"] # pool = enc_out["text_embeds"]
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
hidden_states = hidden_states.to(device) hidden_states = hidden_states.to(device)
if pool is not None: if pool is not None:
pool = pool.to(device) pool = pool.to(device)
@@ -261,7 +262,7 @@ def get_unweighted_text_embeddings(
text_input_chunk[j, 1] = eos text_input_chunk[j, 1] = eos
text_embedding, current_text_pool = get_hidden_states( text_embedding, current_text_pool = get_hidden_states(
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, pipe.device pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
) )
if text_pool is None: if text_pool is None:
text_pool = current_text_pool text_pool = current_text_pool
@@ -280,7 +281,7 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding) text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1) text_embeddings = torch.concat(text_embeddings, axis=1)
else: else:
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, pipe.device) text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
return text_embeddings, text_pool return text_embeddings, text_pool

View File

@@ -34,7 +34,7 @@ import torch
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
from torchvision import transforms from torchvision import transforms
from transformers import CLIPTokenizer from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import ( from diffusers import (
@@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
return encoder_hidden_states return encoder_hidden_states
def pool_workaround(
    text_encoder: "CLIPTextModelWithProjection", last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
    r"""
    Workaround for CLIP's pooling bug: it pools the hidden state of the token with the
    *highest token id* (``input_ids.argmax``) instead of the hidden state at the EOS
    position. With Textual Inversion, newly added tokens get ids larger than the EOS id,
    so the buggy pooling picks the wrong position; we pool at the EOS token explicitly.

    Original code from CLIP's pooling function::

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

    Args:
        text_encoder: model providing the ``text_projection`` layer applied to the pooled state.
        last_hidden_state: final hidden states, shape ``(batch, seq_len, width)``.
        input_ids: token ids fed to the encoder, shape ``(batch, seq_len)``.
        eos_token_id: id of the EOS token to pool at.

    Returns:
        Projected pooled output, shape ``(batch, projection_dim)``.
    """
    # input_ids: b*n, 77
    # Find the FIRST EOS position in each row. argmax over the boolean mask returns the
    # first index of the maximum, so this stays correct even if EOS also serves as the
    # pad token (torch.where(...)[1] would then return more indices than rows and break
    # the advanced indexing below).
    # NOTE(review): a row containing no EOS at all silently pools position 0 — assumed
    # never to happen because the tokenizer always appends EOS; confirm against callers.
    eos_token_index = (input_ids == eos_token_id).int().argmax(dim=-1)
    eos_token_index = eos_token_index.to(device=last_hidden_state.device)

    # gather the hidden state at the EOS position of each sequence
    pooled_output = last_hidden_state[
        torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
    ]

    # apply the same projection CLIP applies after pooling
    pooled_output = text_encoder.text_projection(pooled_output)

    return pooled_output
def get_hidden_states_sdxl( def get_hidden_states_sdxl(
max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None max_token_length: int,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: CLIPTextModel,
text_encoder2: CLIPTextModelWithProjection,
weight_dtype: Optional[str] = None,
): ):
# input_ids: b,n,77 -> b*n, 77 # input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0] b_size = input_ids1.size()[0]
@@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl(
# text_encoder2 # text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
pool2 = enc_out["text_embeds"]
# pool2 = enc_out["text_embeds"]
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}")
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75 n_size = 1 if max_token_length is None else max_token_length // 75

View File

@@ -94,7 +94,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
replace_vae_attn_to_memory_efficient() replace_vae_attn_to_memory_efficient()
elif xformers: elif xformers:
# replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
elif sdpa: elif sdpa:
replace_vae_attn_to_sdpa() replace_vae_attn_to_sdpa()
@@ -960,6 +960,8 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-2] text_embedding = enc_out["hidden_states"][-2]
if pool is None: if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
if no_boseos_middle: if no_boseos_middle:
if i == 0: if i == 0:
@@ -978,6 +980,8 @@ def get_unweighted_text_embeddings(
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-2] text_embeddings = enc_out["hidden_states"][-2]
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos)
return text_embeddings, pool return text_embeddings, pool

View File

@@ -213,7 +213,7 @@ if __name__ == "__main__":
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2] text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape) # print("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] text_embedding2_pool = enc_out["text_embeds"] # does not support Textual Inversion
# 連結して終了 concat and finish # 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2) text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)