mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add workaround for clip's bug for pooled output
This commit is contained in:
@@ -34,7 +34,7 @@ import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
import transformers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import (
|
||||
@@ -2868,7 +2868,7 @@ def verify_training_args(args: argparse.Namespace):
|
||||
raise ValueError(
|
||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||
)
|
||||
|
||||
|
||||
if args.v_pred_like_loss and args.v_parameterization:
|
||||
raise ValueError(
|
||||
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
|
||||
@@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
def pool_workaround(
|
||||
text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
|
||||
):
|
||||
r"""
|
||||
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
|
||||
instead of the hidden states for the EOS token
|
||||
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
|
||||
|
||||
Original code from CLIP's pooling function:
|
||||
|
||||
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
\# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||
]
|
||||
"""
|
||||
|
||||
# input_ids: b*n,77
|
||||
# find index for EOS token
|
||||
eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
||||
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||
print(eos_token_index)
|
||||
print(input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1))
|
||||
|
||||
# get hidden states for EOS token
|
||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
||||
|
||||
# apply projection
|
||||
pooled_output = text_encoder.text_projection(pooled_output)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
def get_hidden_states_sdxl(
|
||||
max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
|
||||
max_token_length: int,
|
||||
input_ids1: torch.Tensor,
|
||||
input_ids2: torch.Tensor,
|
||||
tokenizer1: CLIPTokenizer,
|
||||
tokenizer2: CLIPTokenizer,
|
||||
text_encoder1: CLIPTextModel,
|
||||
text_encoder2: CLIPTextModelWithProjection,
|
||||
weight_dtype: Optional[str] = None,
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
@@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl(
|
||||
# text_encoder2
|
||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
pool2 = enc_out["text_embeds"]
|
||||
|
||||
# pool2 = enc_out["text_embeds"]
|
||||
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||
print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}")
|
||||
|
||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||
|
||||
Reference in New Issue
Block a user