mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add workaround for clip's bug for pooled output
This commit is contained in:
@@ -18,7 +18,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput,
|
|||||||
from diffusers.utils import logging
|
from diffusers.utils import logging
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from library import sdxl_model_util, sdxl_train_util
|
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -210,7 +210,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
|
|||||||
return tokens, weights
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, device):
|
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
|
||||||
if not is_sdxl_text_encoder2:
|
if not is_sdxl_text_encoder2:
|
||||||
# text_encoder1: same as SD1/2
|
# text_encoder1: same as SD1/2
|
||||||
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
||||||
@@ -220,7 +220,8 @@ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, devi
|
|||||||
# text_encoder2
|
# text_encoder2
|
||||||
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
||||||
hidden_states = enc_out["hidden_states"][-2] # penultimate layer
|
hidden_states = enc_out["hidden_states"][-2] # penultimate layer
|
||||||
pool = enc_out["text_embeds"]
|
# pool = enc_out["text_embeds"]
|
||||||
|
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
|
||||||
hidden_states = hidden_states.to(device)
|
hidden_states = hidden_states.to(device)
|
||||||
if pool is not None:
|
if pool is not None:
|
||||||
pool = pool.to(device)
|
pool = pool.to(device)
|
||||||
@@ -261,7 +262,7 @@ def get_unweighted_text_embeddings(
|
|||||||
text_input_chunk[j, 1] = eos
|
text_input_chunk[j, 1] = eos
|
||||||
|
|
||||||
text_embedding, current_text_pool = get_hidden_states(
|
text_embedding, current_text_pool = get_hidden_states(
|
||||||
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, pipe.device
|
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
|
||||||
)
|
)
|
||||||
if text_pool is None:
|
if text_pool is None:
|
||||||
text_pool = current_text_pool
|
text_pool = current_text_pool
|
||||||
@@ -280,7 +281,7 @@ def get_unweighted_text_embeddings(
|
|||||||
text_embeddings.append(text_embedding)
|
text_embeddings.append(text_embedding)
|
||||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||||
else:
|
else:
|
||||||
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, pipe.device)
|
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
|
||||||
return text_embeddings, text_pool
|
return text_embeddings, text_pool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ import torch
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||||
import transformers
|
import transformers
|
||||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
@@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
|||||||
return encoder_hidden_states
|
return encoder_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def pool_workaround(
    text_encoder: "CLIPTextModelWithProjection",
    last_hidden_state: torch.Tensor,
    input_ids: torch.Tensor,
    eos_token_id: int,
):
    """
    Compute CLIP's pooled text embedding from the hidden state at the EOS token.

    Workaround for CLIP's pooling bug: the stock implementation returns the
    hidden state at the position of the *largest token id* as the pooled
    output, instead of the hidden state at the EOS token:

        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

    With Textual Inversion, added tokens get ids larger than the EOS id, so
    the stock lookup picks the wrong position. This function locates the EOS
    token explicitly and then applies the encoder's text projection.

    Args:
        text_encoder: model exposing a ``text_projection`` layer.
        last_hidden_state: final-layer hidden states, indexed as
            ``[row, position]``.
        input_ids: 2-D token ids, one row per row of ``last_hidden_state``
            (the original code notes the shape as ``b*n, 77``).
        eos_token_id: id of the EOS token to search for.

    Returns:
        The projected hidden state at the first EOS position of each row,
        in the dtype of ``last_hidden_state``.
    """
    device = last_hidden_state.device

    # One index per row: build a 0/1 mask of EOS positions and take argmax,
    # which yields the FIRST EOS in each row. Unlike torch.where(...)[1],
    # this stays aligned with the batch when a row holds several EOS tokens
    # (e.g. EOS used as padding). A row without any EOS falls back to index 0.
    eos_mask = (input_ids == eos_token_id).to(device=device, dtype=torch.int)
    eos_token_index = eos_mask.argmax(dim=-1)

    # get hidden states for EOS token
    batch_index = torch.arange(last_hidden_state.shape[0], device=device)
    pooled_output = last_hidden_state[batch_index, eos_token_index]

    # apply projection; the projection weights may use a different dtype than
    # the hidden states (mixed precision), so cast in and back out.
    projection = text_encoder.text_projection
    pooled_output = projection(pooled_output.to(projection.weight.dtype))

    return pooled_output.to(last_hidden_state.dtype)
|
||||||
|
|
||||||
|
|
||||||
def get_hidden_states_sdxl(
|
def get_hidden_states_sdxl(
|
||||||
max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
|
max_token_length: int,
|
||||||
|
input_ids1: torch.Tensor,
|
||||||
|
input_ids2: torch.Tensor,
|
||||||
|
tokenizer1: CLIPTokenizer,
|
||||||
|
tokenizer2: CLIPTokenizer,
|
||||||
|
text_encoder1: CLIPTextModel,
|
||||||
|
text_encoder2: CLIPTextModelWithProjection,
|
||||||
|
weight_dtype: Optional[str] = None,
|
||||||
):
|
):
|
||||||
# input_ids: b,n,77 -> b*n, 77
|
# input_ids: b,n,77 -> b*n, 77
|
||||||
b_size = input_ids1.size()[0]
|
b_size = input_ids1.size()[0]
|
||||||
@@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl(
|
|||||||
# text_encoder2
|
# text_encoder2
|
||||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||||
hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
|
hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
|
||||||
pool2 = enc_out["text_embeds"]
|
|
||||||
|
# pool2 = enc_out["text_embeds"]
|
||||||
|
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||||
|
print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}")
|
||||||
|
|
||||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
|||||||
replace_vae_attn_to_memory_efficient()
|
replace_vae_attn_to_memory_efficient()
|
||||||
elif xformers:
|
elif xformers:
|
||||||
# replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す?
|
# replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す?
|
||||||
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
|
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
|
||||||
elif sdpa:
|
elif sdpa:
|
||||||
replace_vae_attn_to_sdpa()
|
replace_vae_attn_to_sdpa()
|
||||||
|
|
||||||
@@ -960,6 +960,8 @@ def get_unweighted_text_embeddings(
|
|||||||
text_embedding = enc_out["hidden_states"][-2]
|
text_embedding = enc_out["hidden_states"][-2]
|
||||||
if pool is None:
|
if pool is None:
|
||||||
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
|
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
|
||||||
|
if pool is not None:
|
||||||
|
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
|
||||||
|
|
||||||
if no_boseos_middle:
|
if no_boseos_middle:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -978,6 +980,8 @@ def get_unweighted_text_embeddings(
|
|||||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||||
text_embeddings = enc_out["hidden_states"][-2]
|
text_embeddings = enc_out["hidden_states"][-2]
|
||||||
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
|
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
|
||||||
|
if pool is not None:
|
||||||
|
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
|
||||||
return text_embeddings, pool
|
return text_embeddings, pool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ if __name__ == "__main__":
|
|||||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||||
# print("hidden_states2", text_embedding2_penu.shape)
|
# print("hidden_states2", text_embedding2_penu.shape)
|
||||||
text_embedding2_pool = enc_out["text_embeds"]
|
text_embedding2_pool = enc_out["text_embeds"] # does not support Textual Inversion
|
||||||
|
|
||||||
# 連結して終了 concat and finish
|
# 連結して終了 concat and finish
|
||||||
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user