Add workaround for CLIP's pooled output bug

Author: Kohya S
Date: 2023-08-04 08:38:27 +09:00
Parent: cf6832896f
Commit: c6d52fdea4
4 changed files with 61 additions and 11 deletions


@@ -34,7 +34,7 @@ import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
@@ -2868,7 +2868,7 @@ def verify_training_args(args: argparse.Namespace):
raise ValueError(
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
)
if args.v_pred_like_loss and args.v_parameterization:
raise ValueError(
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
@@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
return encoder_hidden_states
def pool_workaround(
text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
r"""
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
instead of the hidden states for the EOS token
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
Original code from CLIP's pooling function:
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
\# take features from the eot embedding (eot_token is the highest number in each sequence)
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
"""
# input_ids: b*n, 77
# find the index of the EOS token in each sequence
# NOTE: assumes each sequence contains the EOS token id exactly once
eos_token_index = torch.where(input_ids == eos_token_id)[1]
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# debug: compare the EOS index with the argmax index that CLIP's original pooling would use
print(eos_token_index)
print(input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1))
# get hidden states for the EOS token
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
# apply projection
pooled_output = text_encoder.text_projection(pooled_output)
return pooled_output
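
For illustration, here is a minimal self-contained sketch of the bug this function works around. The checkpoint name and the <my-ti-token> token are assumptions for the demo, not part of this commit; on transformers versions whose CLIP pooling still takes argmax over input_ids, the built-in pooled output diverges from the EOS-based one as soon as a Textual Inversion token appears in the prompt.

import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

model_name = "openai/clip-vit-large-patch14"  # any CLIPTextModelWithProjection checkpoint works
tokenizer = CLIPTokenizer.from_pretrained(model_name)
text_encoder = CLIPTextModelWithProjection.from_pretrained(model_name)

# simulate Textual Inversion: the new token id (49408) is greater than eos_token_id (49407)
tokenizer.add_tokens(["<my-ti-token>"])
text_encoder.resize_token_embeddings(len(tokenizer))

input_ids = tokenizer(
    "a photo of <my-ti-token>", padding="max_length",
    max_length=tokenizer.model_max_length, return_tensors="pt",
).input_ids

# CLIP's original pooling picks the position of the max token id -> the TI token, not EOS
argmax_index = input_ids.to(torch.int).argmax(dim=-1)
# the first EOS occurrence is the real end of the prompt (this tokenizer pads with EOS)
eos_index = (input_ids == tokenizer.eos_token_id).int().argmax(dim=-1)
print(argmax_index.item(), eos_index.item())  # the two indices differ

enc_out = text_encoder(input_ids, return_dict=True)
pooled_fixed = text_encoder.text_projection(
    enc_out.last_hidden_state[torch.arange(input_ids.shape[0]), eos_index]
)
# on affected transformers versions, enc_out.text_embeds != pooled_fixed
print(torch.allclose(enc_out.text_embeds, pooled_fixed))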
def get_hidden_states_sdxl(
-    max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
+    max_token_length: int,
+    input_ids1: torch.Tensor,
+    input_ids2: torch.Tensor,
+    tokenizer1: CLIPTokenizer,
+    tokenizer2: CLIPTokenizer,
+    text_encoder1: CLIPTextModel,
+    text_encoder2: CLIPTextModelWithProjection,
+    weight_dtype: Optional[str] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
@@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl(
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
-    pool2 = enc_out["text_embeds"]
+    # pool2 = enc_out["text_embeds"]
+    pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
+    # debug: compare the buggy pooled output with the workaround's result
+    print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}")
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
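
As a side note, the b*n, 77 layout referred to throughout this hunk comes from the long-prompt chunking: prompts longer than 75 tokens are split into n chunks of 77 ids each, encoded as a batch of b*n sequences, and the hidden states are then re-joined per prompt. A sketch with stand-in tensors (all names and shapes here are illustrative, not the repository's code):

import torch

b, max_token_length, hidden_dim = 2, 225, 1280  # 225 = 75 * 3 -> n = 3 chunks per prompt
n = max_token_length // 75                      # mirrors n_size above

input_ids = torch.randint(0, 49408, (b, n, 77))  # b, n, 77 (stand-in token ids)
flat_ids = input_ids.reshape(b * n, 77)          # b*n, 77: what the text encoder consumes

hidden_states = torch.randn(b * n, 77, hidden_dim)        # per-chunk encoder output
hidden_states = hidden_states.reshape(b, -1, hidden_dim)  # b, n*77, hidden_dim
print(hidden_states.shape)  # torch.Size([2, 231, 1280])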