Fix lpwp to support sdv2 and clip skip

This commit is contained in:
Kohya S
2023-03-19 11:10:17 +09:00
parent cfb19ad0da
commit 1f7babd2c7
2 changed files with 105 additions and 82 deletions

View File

@@ -1,3 +1,6 @@
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# and modify to support SD2.x
import inspect import inspect
import re import re
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
@@ -208,6 +211,9 @@ def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
text_input: torch.Tensor, text_input: torch.Tensor,
chunk_length: int, chunk_length: int,
clip_skip: int,
eos: int,
pad: int,
no_boseos_middle: Optional[bool] = True, no_boseos_middle: Optional[bool] = True,
): ):
""" """
@@ -223,8 +229,26 @@ def get_unweighted_text_embeddings(
# cover the head and the tail by the starting and the ending tokens # cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0] text_input_chunk[:, 0] = text_input[0, 0]
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1] text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk,attention_mask=None)[0] else: # v2
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos
if clip_skip is None or clip_skip == 1:
text_embedding = pipe.text_encoder(text_input_chunk)[0]
else:
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle: if no_boseos_middle:
if i == 0: if i == 0:
@@ -252,6 +276,7 @@ def get_weighted_text_embeddings(
no_boseos_middle: Optional[bool] = False, no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False, skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False, skip_weighting: Optional[bool] = False,
clip_skip=None,
): ):
r""" r"""
Prompts can be assigned with local weights using brackets. For example, Prompts can be assigned with local weights using brackets. For example,
@@ -289,16 +314,13 @@ def get_weighted_text_embeddings(
uncond_prompt = [uncond_prompt] uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
else: else:
prompt_tokens = [ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens] prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None: if uncond_prompt is not None:
if isinstance(uncond_prompt, str): if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt] uncond_prompt = [uncond_prompt]
uncond_tokens = [ uncond_tokens = [
token[1:-1] token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
] ]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens] uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -317,6 +339,7 @@ def get_weighted_text_embeddings(
# pad the length of tokens and weights # pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id eos = pipe.tokenizer.eos_token_id
pad = pipe.tokenizer.pad_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights( prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens, prompt_tokens,
prompt_weights, prompt_weights,
@@ -344,6 +367,9 @@ def get_weighted_text_embeddings(
pipe, pipe,
prompt_tokens, prompt_tokens,
pipe.tokenizer.model_max_length, pipe.tokenizer.model_max_length,
clip_skip,
eos,
pad,
no_boseos_middle=no_boseos_middle, no_boseos_middle=no_boseos_middle,
) )
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
@@ -352,6 +378,9 @@ def get_weighted_text_embeddings(
pipe, pipe,
uncond_tokens, uncond_tokens,
pipe.tokenizer.model_max_length, pipe.tokenizer.model_max_length,
clip_skip,
eos,
pad,
no_boseos_middle=no_boseos_middle, no_boseos_middle=no_boseos_middle,
) )
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
@@ -426,7 +455,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
def __init__( def __init__(
self, self,
@@ -435,6 +464,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: SchedulerMixin, scheduler: SchedulerMixin,
clip_skip: int,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
@@ -449,30 +479,30 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker, requires_safety_checker=requires_safety_checker,
) )
self.clip_skip = clip_skip
self.__init__additional__() self.__init__additional__()
else: # else:
# def __init__(
def __init__( # self,
self, # vae: AutoencoderKL,
vae: AutoencoderKL, # text_encoder: CLIPTextModel,
text_encoder: CLIPTextModel, # tokenizer: CLIPTokenizer,
tokenizer: CLIPTokenizer, # unet: UNet2DConditionModel,
unet: UNet2DConditionModel, # scheduler: SchedulerMixin,
scheduler: SchedulerMixin, # safety_checker: StableDiffusionSafetyChecker,
safety_checker: StableDiffusionSafetyChecker, # feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPFeatureExtractor, # ):
): # super().__init__(
super().__init__( # vae=vae,
vae=vae, # text_encoder=text_encoder,
text_encoder=text_encoder, # tokenizer=tokenizer,
tokenizer=tokenizer, # unet=unet,
unet=unet, # scheduler=scheduler,
scheduler=scheduler, # safety_checker=safety_checker,
safety_checker=safety_checker, # feature_extractor=feature_extractor,
feature_extractor=feature_extractor, # )
) # self.__init__additional__()
self.__init__additional__()
def __init__additional__(self): def __init__additional__(self):
if not hasattr(self, "vae_scale_factor"): if not hasattr(self, "vae_scale_factor"):
@@ -541,6 +571,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
prompt=prompt, prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None, uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples, max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
) )
bs_embed, seq_len, _ = text_embeddings.shape bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -562,15 +593,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
print(height,width) print(height, width)
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
): ):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
f" {type(callback_steps)}."
) )
def get_timesteps(self, num_inference_steps, strength, device, is_text2img): def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
@@ -589,9 +619,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None: if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker( image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else: else:
has_nsfw_concept = None has_nsfw_concept = None
return image, has_nsfw_concept return image, has_nsfw_concept

View File

@@ -2703,8 +2703,7 @@ def sample_images(
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
): ):
""" """
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけに対応していない StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
clip skipは対応した
""" """
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return return
@@ -2724,26 +2723,6 @@ def sample_images(
org_vae_device = vae.device # CPUにいるはず org_vae_device = vae.device # CPUにいるはず
vae.to(device) vae.to(device)
# clip skip 対応のための wrapper を作る
if args.clip_skip is None:
text_encoder_or_wrapper = text_encoder
else:
class Wrapper:
def __init__(self, tenc) -> None:
self.tenc = tenc
self.config = {}
super().__init__()
def __call__(self, input_ids, attention_mask):
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
pooled_output = enc_out["pooler_output"]
return encoder_hidden_states, pooled_output # 1st output is only used
text_encoder_or_wrapper = Wrapper(text_encoder)
# read prompts # read prompts
with open(args.sample_prompts, "rt", encoding="utf-8") as f: with open(args.sample_prompts, "rt", encoding="utf-8") as f:
prompts = f.readlines() prompts = f.readlines()
@@ -2792,8 +2771,17 @@ def sample_images(
# print("set clip_sample to True") # print("set clip_sample to True")
scheduler.config.clip_sample = True scheduler.config.clip_sample = True
pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer, pipeline = StableDiffusionLongPromptWeightingPipeline(
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False) text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
clip_skip=args.clip_skip,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
pipeline.to(device) pipeline.to(device)
save_dir = args.output_dir + "/sample" save_dir = args.output_dir + "/sample"
@@ -2872,7 +2860,14 @@ def sample_images(
print(f"width: {width}") print(f"width: {width}")
print(f"sample_steps: {sample_steps}") print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}") print(f"scale: {scale}")
image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0] image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
).images[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"