Add workaround for CLIP's bug for pooled output

This commit is contained in:
Kohya S
2023-08-04 08:38:27 +09:00
parent cf6832896f
commit c6d52fdea4
4 changed files with 61 additions and 11 deletions

View File

@@ -94,7 +94,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
replace_vae_attn_to_memory_efficient()
elif xformers:
# replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
elif sdpa:
replace_vae_attn_to_sdpa()
@@ -960,6 +960,8 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-2]
if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
if no_boseos_middle:
if i == 0:
@@ -978,6 +980,8 @@ def get_unweighted_text_embeddings(
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-2]
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
return text_embeddings, pool