Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
fix pad token is not handled
@@ -185,14 +185,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], m
     return tokens, weights


-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
     for i in range(len(tokens)):
-        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
         if no_boseos_middle:
             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
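The effect of the padding change is easiest to see on one short prompt. The sketch below is illustrative only: the token ids are made-up CLIP-style values and max_length is shortened to 10 so the lists fit on a line; in the pipeline all of these come from the tokenizer. Before the fix every slot after the prompt was filled with eos; after it, exactly one eos is appended and the rest is the dedicated pad token, which matters for tokenizers whose pad token differs from eos.

    bos, eos, pad = 49406, 49407, 0        # hypothetical ids; the point is pad != eos
    max_length = 10
    prompt_ids = [320, 1125, 539]          # a prompt tokenized without special tokens

    old = [bos] + prompt_ids + [eos] * (max_length - 1 - len(prompt_ids))
    new = [bos] + prompt_ids + [eos] + [pad] * (max_length - 2 - len(prompt_ids))

    print(old)  # [49406, 320, 1125, 539, 49407, 49407, 49407, 49407, 49407, 49407]
    print(new)  # [49406, 320, 1125, 539, 49407, 0, 0, 0, 0, 0]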
@@ -363,6 +363,7 @@ def get_weighted_text_embeddings(
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
@@ -374,6 +375,7 @@ def get_weighted_text_embeddings(
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
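Both call sites above (one for the prompt, one for the negative prompt) now forward a pad id into pad_tokens_and_weights. A minimal sketch of how that id would typically be obtained, assuming a Hugging Face tokenizer on pipe; the fallback to eos covers tokenizers, such as the SD1.x CLIP tokenizer, that define no separate pad token. Variable names here are illustrative, not the exact code of this commit.

    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    pad = pipe.tokenizer.pad_token_id if pipe.tokenizer.pad_token_id is not None else eos

    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )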
@@ -711,7 +713,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
         # self.vae.set_use_memory_efficient_attention_xformers(False)
         # image = self.vae.decode(latents.to("cpu")).sample

-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents.to(self.vae.dtype)).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
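The last hunk casts the latents to the VAE's dtype before decoding. Without the cast, a pipeline whose VAE weights are loaded in fp16 or bf16 while the latents are held in fp32 (or the reverse) fails inside vae.decode with a dtype-mismatch error from the convolution layers. A small self-contained illustration of the cast itself; the names and dtypes here are assumptions, not taken from the repository.

    import torch

    latents = torch.randn(1, 4, 128, 128, dtype=torch.float32)  # e.g. latents kept in fp32
    vae_dtype = torch.float16                                    # dtype the VAE weights were loaded in

    if latents.dtype != vae_dtype:
        latents = latents.to(vae_dtype)   # the fix: match the module's dtype before decode()

    print(latents.dtype)  # torch.float16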