From a380502c01a9d99281f9ea0d7486ac2b03c7147c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 18:13:49 +0900 Subject: [PATCH] fix pad token is not handled --- library/sdxl_lpw_stable_diffusion.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e806bc61..10038e6a 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -185,14 +185,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], m return tokens, weights -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) if no_boseos_middle: weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) else: @@ -363,6 +363,7 @@ def get_weighted_text_embeddings( max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -374,6 +375,7 @@ def get_weighted_text_embeddings( max_length, bos, eos, + pad, no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) @@ -711,7 +713,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: # self.vae.set_use_memory_efficient_attention_xformers(False) # image = self.vae.decode(latents.to("cpu")).sample - image = self.vae.decode(latents).sample + image = self.vae.decode(latents.to(self.vae.dtype)).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy()