Fix lpwp to support sdv2 and clip skip

This commit is contained in:
Kohya S
2023-03-19 11:10:17 +09:00
parent cfb19ad0da
commit 1f7babd2c7
2 changed files with 105 additions and 82 deletions

View File

@@ -1,3 +1,6 @@
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# and modify to support SD2.x
import inspect
import re
from typing import Callable, List, Optional, Union
@@ -208,6 +211,9 @@ def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
clip_skip: int,
eos: int,
pad: int,
no_boseos_middle: Optional[bool] = True,
):
"""
@@ -221,10 +227,28 @@ def get_unweighted_text_embeddings(
# extract the i-th chunk
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1]
else: # v2
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos
if clip_skip is None or clip_skip == 1:
text_embedding = pipe.text_encoder(text_input_chunk)[0]
else:
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk,attention_mask=None)[0]
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
@@ -252,6 +276,7 @@ def get_weighted_text_embeddings(
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
clip_skip=None,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
@@ -289,16 +314,13 @@ def get_weighted_text_embeddings(
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
else:
prompt_tokens = [
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
]
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -317,6 +339,7 @@ def get_weighted_text_embeddings(
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
pad = pipe.tokenizer.pad_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
@@ -344,6 +367,9 @@ def get_weighted_text_embeddings(
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
clip_skip,
eos,
pad,
no_boseos_middle=no_boseos_middle,
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
@@ -352,6 +378,9 @@ def get_weighted_text_embeddings(
pipe,
uncond_tokens,
pipe.tokenizer.model_max_length,
clip_skip,
eos,
pad,
no_boseos_middle=no_boseos_middle,
)
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
@@ -426,53 +455,54 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
)
self.__init__additional__()
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
)
self.clip_skip = clip_skip
self.__init__additional__()
else:
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.__init__additional__()
# else:
# def __init__(
# self,
# vae: AutoencoderKL,
# text_encoder: CLIPTextModel,
# tokenizer: CLIPTokenizer,
# unet: UNet2DConditionModel,
# scheduler: SchedulerMixin,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPFeatureExtractor,
# ):
# super().__init__(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# scheduler=scheduler,
# safety_checker=safety_checker,
# feature_extractor=feature_extractor,
# )
# self.__init__additional__()
def __init__additional__(self):
if not hasattr(self, "vae_scale_factor"):
@@ -541,6 +571,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -562,15 +593,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
print(height,width)
print(height, width)
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
)
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
@@ -589,9 +619,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
else:
has_nsfw_concept = None
return image, has_nsfw_concept