mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix the long-prompt-weighting pipeline (lpwp) to support SD2.x and clip skip
This commit is contained in:
@@ -1,3 +1,6 @@
|
|||||||
|
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
||||||
|
# and modify to support SD2.x
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
@@ -208,6 +211,9 @@ def get_unweighted_text_embeddings(
|
|||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
text_input: torch.Tensor,
|
text_input: torch.Tensor,
|
||||||
chunk_length: int,
|
chunk_length: int,
|
||||||
|
clip_skip: int,
|
||||||
|
eos: int,
|
||||||
|
pad: int,
|
||||||
no_boseos_middle: Optional[bool] = True,
|
no_boseos_middle: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -221,10 +227,28 @@ def get_unweighted_text_embeddings(
|
|||||||
# extract the i-th chunk
|
# extract the i-th chunk
|
||||||
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
||||||
|
|
||||||
|
# cover the head and the tail by the starting and the ending tokens
|
||||||
|
text_input_chunk[:, 0] = text_input[0, 0]
|
||||||
|
if pad == eos: # v1
|
||||||
|
text_input_chunk[:, -1] = text_input[0, -1]
|
||||||
|
else: # v2
|
||||||
|
for j in range(len(text_input_chunk)):
|
||||||
|
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
||||||
|
text_input_chunk[j, -1] = eos
|
||||||
|
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
||||||
|
text_input_chunk[j, 1] = eos
|
||||||
|
|
||||||
|
if clip_skip is None or clip_skip == 1:
|
||||||
|
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
||||||
|
else:
|
||||||
|
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
||||||
|
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||||
|
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
||||||
|
|
||||||
# cover the head and the tail by the starting and the ending tokens
|
# cover the head and the tail by the starting and the ending tokens
|
||||||
text_input_chunk[:, 0] = text_input[0, 0]
|
text_input_chunk[:, 0] = text_input[0, 0]
|
||||||
text_input_chunk[:, -1] = text_input[0, -1]
|
text_input_chunk[:, -1] = text_input[0, -1]
|
||||||
text_embedding = pipe.text_encoder(text_input_chunk,attention_mask=None)[0]
|
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||||
|
|
||||||
if no_boseos_middle:
|
if no_boseos_middle:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -252,6 +276,7 @@ def get_weighted_text_embeddings(
|
|||||||
no_boseos_middle: Optional[bool] = False,
|
no_boseos_middle: Optional[bool] = False,
|
||||||
skip_parsing: Optional[bool] = False,
|
skip_parsing: Optional[bool] = False,
|
||||||
skip_weighting: Optional[bool] = False,
|
skip_weighting: Optional[bool] = False,
|
||||||
|
clip_skip=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Prompts can be assigned with local weights using brackets. For example,
|
Prompts can be assigned with local weights using brackets. For example,
|
||||||
@@ -289,16 +314,13 @@ def get_weighted_text_embeddings(
|
|||||||
uncond_prompt = [uncond_prompt]
|
uncond_prompt = [uncond_prompt]
|
||||||
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
||||||
else:
|
else:
|
||||||
prompt_tokens = [
|
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
||||||
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
|
|
||||||
]
|
|
||||||
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
||||||
if uncond_prompt is not None:
|
if uncond_prompt is not None:
|
||||||
if isinstance(uncond_prompt, str):
|
if isinstance(uncond_prompt, str):
|
||||||
uncond_prompt = [uncond_prompt]
|
uncond_prompt = [uncond_prompt]
|
||||||
uncond_tokens = [
|
uncond_tokens = [
|
||||||
token[1:-1]
|
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
||||||
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
|
||||||
]
|
]
|
||||||
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
||||||
|
|
||||||
@@ -317,6 +339,7 @@ def get_weighted_text_embeddings(
|
|||||||
# pad the length of tokens and weights
|
# pad the length of tokens and weights
|
||||||
bos = pipe.tokenizer.bos_token_id
|
bos = pipe.tokenizer.bos_token_id
|
||||||
eos = pipe.tokenizer.eos_token_id
|
eos = pipe.tokenizer.eos_token_id
|
||||||
|
pad = pipe.tokenizer.pad_token_id
|
||||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
prompt_weights,
|
prompt_weights,
|
||||||
@@ -344,6 +367,9 @@ def get_weighted_text_embeddings(
|
|||||||
pipe,
|
pipe,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
pipe.tokenizer.model_max_length,
|
pipe.tokenizer.model_max_length,
|
||||||
|
clip_skip,
|
||||||
|
eos,
|
||||||
|
pad,
|
||||||
no_boseos_middle=no_boseos_middle,
|
no_boseos_middle=no_boseos_middle,
|
||||||
)
|
)
|
||||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
||||||
@@ -352,6 +378,9 @@ def get_weighted_text_embeddings(
|
|||||||
pipe,
|
pipe,
|
||||||
uncond_tokens,
|
uncond_tokens,
|
||||||
pipe.tokenizer.model_max_length,
|
pipe.tokenizer.model_max_length,
|
||||||
|
clip_skip,
|
||||||
|
eos,
|
||||||
|
pad,
|
||||||
no_boseos_middle=no_boseos_middle,
|
no_boseos_middle=no_boseos_middle,
|
||||||
)
|
)
|
||||||
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
||||||
@@ -426,53 +455,54 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vae: AutoencoderKL,
|
vae: AutoencoderKL,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
scheduler: SchedulerMixin,
|
scheduler: SchedulerMixin,
|
||||||
safety_checker: StableDiffusionSafetyChecker,
|
clip_skip: int,
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
safety_checker: StableDiffusionSafetyChecker,
|
||||||
requires_safety_checker: bool = True,
|
feature_extractor: CLIPFeatureExtractor,
|
||||||
):
|
requires_safety_checker: bool = True,
|
||||||
super().__init__(
|
):
|
||||||
vae=vae,
|
super().__init__(
|
||||||
text_encoder=text_encoder,
|
vae=vae,
|
||||||
tokenizer=tokenizer,
|
text_encoder=text_encoder,
|
||||||
unet=unet,
|
tokenizer=tokenizer,
|
||||||
scheduler=scheduler,
|
unet=unet,
|
||||||
safety_checker=safety_checker,
|
scheduler=scheduler,
|
||||||
feature_extractor=feature_extractor,
|
safety_checker=safety_checker,
|
||||||
requires_safety_checker=requires_safety_checker,
|
feature_extractor=feature_extractor,
|
||||||
)
|
requires_safety_checker=requires_safety_checker,
|
||||||
self.__init__additional__()
|
)
|
||||||
|
self.clip_skip = clip_skip
|
||||||
|
self.__init__additional__()
|
||||||
|
|
||||||
else:
|
# else:
|
||||||
|
# def __init__(
|
||||||
def __init__(
|
# self,
|
||||||
self,
|
# vae: AutoencoderKL,
|
||||||
vae: AutoencoderKL,
|
# text_encoder: CLIPTextModel,
|
||||||
text_encoder: CLIPTextModel,
|
# tokenizer: CLIPTokenizer,
|
||||||
tokenizer: CLIPTokenizer,
|
# unet: UNet2DConditionModel,
|
||||||
unet: UNet2DConditionModel,
|
# scheduler: SchedulerMixin,
|
||||||
scheduler: SchedulerMixin,
|
# safety_checker: StableDiffusionSafetyChecker,
|
||||||
safety_checker: StableDiffusionSafetyChecker,
|
# feature_extractor: CLIPFeatureExtractor,
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
# ):
|
||||||
):
|
# super().__init__(
|
||||||
super().__init__(
|
# vae=vae,
|
||||||
vae=vae,
|
# text_encoder=text_encoder,
|
||||||
text_encoder=text_encoder,
|
# tokenizer=tokenizer,
|
||||||
tokenizer=tokenizer,
|
# unet=unet,
|
||||||
unet=unet,
|
# scheduler=scheduler,
|
||||||
scheduler=scheduler,
|
# safety_checker=safety_checker,
|
||||||
safety_checker=safety_checker,
|
# feature_extractor=feature_extractor,
|
||||||
feature_extractor=feature_extractor,
|
# )
|
||||||
)
|
# self.__init__additional__()
|
||||||
self.__init__additional__()
|
|
||||||
|
|
||||||
def __init__additional__(self):
|
def __init__additional__(self):
|
||||||
if not hasattr(self, "vae_scale_factor"):
|
if not hasattr(self, "vae_scale_factor"):
|
||||||
@@ -541,6 +571,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||||
max_embeddings_multiples=max_embeddings_multiples,
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
|
clip_skip=self.clip_skip,
|
||||||
)
|
)
|
||||||
bs_embed, seq_len, _ = text_embeddings.shape
|
bs_embed, seq_len, _ = text_embeddings.shape
|
||||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||||
@@ -562,15 +593,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||||
|
|
||||||
if height % 8 != 0 or width % 8 != 0:
|
if height % 8 != 0 or width % 8 != 0:
|
||||||
print(height,width)
|
print(height, width)
|
||||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||||
|
|
||||||
if (callback_steps is None) or (
|
if (callback_steps is None) or (
|
||||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
||||||
f" {type(callback_steps)}."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
||||||
@@ -589,9 +619,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
def run_safety_checker(self, image, device, dtype):
|
def run_safety_checker(self, image, device, dtype):
|
||||||
if self.safety_checker is not None:
|
if self.safety_checker is not None:
|
||||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||||
image, has_nsfw_concept = self.safety_checker(
|
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
||||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
has_nsfw_concept = None
|
has_nsfw_concept = None
|
||||||
return image, has_nsfw_concept
|
return image, has_nsfw_concept
|
||||||
|
|||||||
@@ -2703,8 +2703,7 @@ def sample_images(
|
|||||||
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
|
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
|
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||||
clip skipは対応した
|
|
||||||
"""
|
"""
|
||||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||||
return
|
return
|
||||||
@@ -2724,26 +2723,6 @@ def sample_images(
|
|||||||
org_vae_device = vae.device # CPUにいるはず
|
org_vae_device = vae.device # CPUにいるはず
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
# clip skip 対応のための wrapper を作る
|
|
||||||
if args.clip_skip is None:
|
|
||||||
text_encoder_or_wrapper = text_encoder
|
|
||||||
else:
|
|
||||||
|
|
||||||
class Wrapper:
|
|
||||||
def __init__(self, tenc) -> None:
|
|
||||||
self.tenc = tenc
|
|
||||||
self.config = {}
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask):
|
|
||||||
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
|
|
||||||
encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
|
|
||||||
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
|
|
||||||
pooled_output = enc_out["pooler_output"]
|
|
||||||
return encoder_hidden_states, pooled_output # 1st output is only used
|
|
||||||
|
|
||||||
text_encoder_or_wrapper = Wrapper(text_encoder)
|
|
||||||
|
|
||||||
# read prompts
|
# read prompts
|
||||||
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||||
prompts = f.readlines()
|
prompts = f.readlines()
|
||||||
@@ -2792,8 +2771,17 @@ def sample_images(
|
|||||||
# print("set clip_sample to True")
|
# print("set clip_sample to True")
|
||||||
scheduler.config.clip_sample = True
|
scheduler.config.clip_sample = True
|
||||||
|
|
||||||
pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
|
pipeline = StableDiffusionLongPromptWeightingPipeline(
|
||||||
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
|
text_encoder=text_encoder,
|
||||||
|
vae=vae,
|
||||||
|
unet=unet,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
clip_skip=args.clip_skip,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
pipeline.to(device)
|
pipeline.to(device)
|
||||||
|
|
||||||
save_dir = args.output_dir + "/sample"
|
save_dir = args.output_dir + "/sample"
|
||||||
@@ -2872,7 +2860,14 @@ def sample_images(
|
|||||||
print(f"width: {width}")
|
print(f"width: {width}")
|
||||||
print(f"sample_steps: {sample_steps}")
|
print(f"sample_steps: {sample_steps}")
|
||||||
print(f"scale: {scale}")
|
print(f"scale: {scale}")
|
||||||
image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0]
|
image = pipeline(
|
||||||
|
prompt=prompt,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_inference_steps=sample_steps,
|
||||||
|
guidance_scale=scale,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||||
|
|||||||
Reference in New Issue
Block a user