mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix lpwp to support sdv2 and clip skip
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
||||
# and modify to support SD2.x
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Callable, List, Optional, Union
|
||||
@@ -208,6 +211,9 @@ def get_unweighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
clip_skip: int,
|
||||
eos: int,
|
||||
pad: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
@@ -221,10 +227,28 @@ def get_unweighted_text_embeddings(
|
||||
# extract the i-th chunk
|
||||
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
if pad == eos: # v1
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
else: # v2
|
||||
for j in range(len(text_input_chunk)):
|
||||
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
||||
text_input_chunk[j, -1] = eos
|
||||
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
||||
text_input_chunk[j, 1] = eos
|
||||
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
||||
else:
|
||||
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk,attention_mask=None)[0]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
@@ -252,6 +276,7 @@ def get_weighted_text_embeddings(
|
||||
no_boseos_middle: Optional[bool] = False,
|
||||
skip_parsing: Optional[bool] = False,
|
||||
skip_weighting: Optional[bool] = False,
|
||||
clip_skip=None,
|
||||
):
|
||||
r"""
|
||||
Prompts can be assigned with local weights using brackets. For example,
|
||||
@@ -289,16 +314,13 @@ def get_weighted_text_embeddings(
|
||||
uncond_prompt = [uncond_prompt]
|
||||
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
||||
else:
|
||||
prompt_tokens = [
|
||||
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
|
||||
]
|
||||
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
||||
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
||||
if uncond_prompt is not None:
|
||||
if isinstance(uncond_prompt, str):
|
||||
uncond_prompt = [uncond_prompt]
|
||||
uncond_tokens = [
|
||||
token[1:-1]
|
||||
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
||||
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
||||
]
|
||||
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
||||
|
||||
@@ -317,6 +339,7 @@ def get_weighted_text_embeddings(
|
||||
# pad the length of tokens and weights
|
||||
bos = pipe.tokenizer.bos_token_id
|
||||
eos = pipe.tokenizer.eos_token_id
|
||||
pad = pipe.tokenizer.pad_token_id
|
||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||
prompt_tokens,
|
||||
prompt_weights,
|
||||
@@ -344,6 +367,9 @@ def get_weighted_text_embeddings(
|
||||
pipe,
|
||||
prompt_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
clip_skip,
|
||||
eos,
|
||||
pad,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
||||
@@ -352,6 +378,9 @@ def get_weighted_text_embeddings(
|
||||
pipe,
|
||||
uncond_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
clip_skip,
|
||||
eos,
|
||||
pad,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
||||
@@ -426,53 +455,54 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
||||
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
self.__init__additional__()
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
clip_skip: int,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
self.clip_skip = clip_skip
|
||||
self.__init__additional__()
|
||||
|
||||
else:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.__init__additional__()
|
||||
# else:
|
||||
# def __init__(
|
||||
# self,
|
||||
# vae: AutoencoderKL,
|
||||
# text_encoder: CLIPTextModel,
|
||||
# tokenizer: CLIPTokenizer,
|
||||
# unet: UNet2DConditionModel,
|
||||
# scheduler: SchedulerMixin,
|
||||
# safety_checker: StableDiffusionSafetyChecker,
|
||||
# feature_extractor: CLIPFeatureExtractor,
|
||||
# ):
|
||||
# super().__init__(
|
||||
# vae=vae,
|
||||
# text_encoder=text_encoder,
|
||||
# tokenizer=tokenizer,
|
||||
# unet=unet,
|
||||
# scheduler=scheduler,
|
||||
# safety_checker=safety_checker,
|
||||
# feature_extractor=feature_extractor,
|
||||
# )
|
||||
# self.__init__additional__()
|
||||
|
||||
def __init__additional__(self):
|
||||
if not hasattr(self, "vae_scale_factor"):
|
||||
@@ -541,6 +571,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
prompt=prompt,
|
||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -562,15 +593,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
print(height,width)
|
||||
print(height, width)
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
||||
@@ -589,9 +619,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
@@ -2703,8 +2703,7 @@ def sample_images(
|
||||
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
|
||||
):
|
||||
"""
|
||||
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
|
||||
clip skipは対応した
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
"""
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
@@ -2724,26 +2723,6 @@ def sample_images(
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(device)
|
||||
|
||||
# clip skip 対応のための wrapper を作る
|
||||
if args.clip_skip is None:
|
||||
text_encoder_or_wrapper = text_encoder
|
||||
else:
|
||||
|
||||
class Wrapper:
|
||||
def __init__(self, tenc) -> None:
|
||||
self.tenc = tenc
|
||||
self.config = {}
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input_ids, attention_mask):
|
||||
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
|
||||
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
||||
pooled_output = enc_out["pooler_output"]
|
||||
return encoder_hidden_states, pooled_output # 1st output is only used
|
||||
|
||||
text_encoder_or_wrapper = Wrapper(text_encoder)
|
||||
|
||||
# read prompts
|
||||
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
prompts = f.readlines()
|
||||
@@ -2792,8 +2771,17 @@ def sample_images(
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
|
||||
pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
||||
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
||||
pipeline = StableDiffusionLongPromptWeightingPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
clip_skip=args.clip_skip,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
pipeline.to(device)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
@@ -2872,7 +2860,14 @@ def sample_images(
|
||||
print(f"width: {width}")
|
||||
print(f"sample_steps: {sample_steps}")
|
||||
print(f"scale: {scale}")
|
||||
image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0]
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_steps,
|
||||
guidance_scale=scale,
|
||||
negative_prompt=negative_prompt,
|
||||
).images[0]
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
|
||||
Reference in New Issue
Block a user