diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 1de04237..3e04b887 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -1,3 +1,6 @@ +# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# and modify to support SD2.x + import inspect import re from typing import Callable, List, Optional, Union @@ -208,6 +211,9 @@ def get_unweighted_text_embeddings( pipe: StableDiffusionPipeline, text_input: torch.Tensor, chunk_length: int, + clip_skip: int, + eos: int, + pad: int, no_boseos_middle: Optional[bool] = True, ): """ @@ -221,10 +227,28 @@ def get_unweighted_text_embeddings( # extract the i-th chunk text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = pipe.text_encoder(text_input_chunk)[0] + else: + enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) + # cover the head and the tail by the starting and the ending tokens text_input_chunk[:, 0] = text_input[0, 0] text_input_chunk[:, -1] = text_input[0, -1] - text_embedding = pipe.text_encoder(text_input_chunk,attention_mask=None)[0] + text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0] if no_boseos_middle: if i == 0: @@ -252,6 +276,7 @@ def get_weighted_text_embeddings( no_boseos_middle: Optional[bool] = False, skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, + clip_skip=None, ): r""" Prompts can be assigned with local weights using brackets. For example, @@ -289,16 +314,13 @@ def get_weighted_text_embeddings( uncond_prompt = [uncond_prompt] uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) else: - prompt_tokens = [ - token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids - ] + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] prompt_weights = [[1.0] * len(token) for token in prompt_tokens] if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] uncond_tokens = [ - token[1:-1] - for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids ] uncond_weights = [[1.0] * len(token) for token in uncond_tokens] @@ -317,6 +339,7 @@ def get_weighted_text_embeddings( # pad the length of tokens and weights bos = pipe.tokenizer.bos_token_id eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id prompt_tokens, prompt_weights = pad_tokens_and_weights( prompt_tokens, prompt_weights, @@ -344,6 +367,9 @@ def get_weighted_text_embeddings( pipe, prompt_tokens, pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, no_boseos_middle=no_boseos_middle, ) prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) @@ -352,6 +378,9 @@ def get_weighted_text_embeddings( pipe, uncond_tokens, pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, no_boseos_middle=no_boseos_middle, ) uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) @@ -426,53 +455,54 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): + # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, - ) - self.__init__additional__() + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + clip_skip: int, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, + ) + self.clip_skip = clip_skip + self.__init__additional__() - else: - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.__init__additional__() + # else: + # def __init__( + # self, + # vae: AutoencoderKL, + # text_encoder: CLIPTextModel, + # tokenizer: CLIPTokenizer, + # unet: UNet2DConditionModel, + # scheduler: SchedulerMixin, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPFeatureExtractor, + # ): + # super().__init__( + # vae=vae, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + # unet=unet, + # scheduler=scheduler, + # safety_checker=safety_checker, + # feature_extractor=feature_extractor, + # ) + # self.__init__additional__() def __init__additional__(self): if not hasattr(self, "vae_scale_factor"): @@ -541,6 +571,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): prompt=prompt, uncond_prompt=negative_prompt if do_classifier_free_guidance else None, max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, ) bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) @@ -562,15 +593,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: - print(height,width) + print(height, width) raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) def get_timesteps(self, num_inference_steps, strength, device, is_text2img): @@ -589,9 +619,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) else: has_nsfw_concept = None return image, has_nsfw_concept diff --git a/library/train_util.py b/library/train_util.py index 99d0a06a..80d7fa47 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2703,8 +2703,7 @@ def sample_images( accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None ): """ - 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない - clip skipは対応した + StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: return @@ -2724,26 +2723,6 @@ def sample_images( org_vae_device = vae.device # CPUにいるはず vae.to(device) - # clip skip 対応のための wrapper を作る - if args.clip_skip is None: - text_encoder_or_wrapper = text_encoder - else: - - class Wrapper: - def __init__(self, tenc) -> None: - self.tenc = tenc - self.config = {} - super().__init__() - - def __call__(self, input_ids, attention_mask): - enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] - encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states) - pooled_output = enc_out["pooler_output"] - return encoder_hidden_states, pooled_output # 1st output is only used - - text_encoder_or_wrapper = Wrapper(text_encoder) - # read prompts with open(args.sample_prompts, "rt", encoding="utf-8") as f: prompts = f.readlines() @@ -2792,8 +2771,17 @@ def sample_images( # print("set clip_sample to True") scheduler.config.clip_sample = True - pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False) + pipeline = StableDiffusionLongPromptWeightingPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + clip_skip=args.clip_skip, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) pipeline.to(device) save_dir = args.output_dir + "/sample" @@ -2872,7 +2860,14 @@ def sample_images( print(f"width: {width}") print(f"sample_steps: {sample_steps}") print(f"scale: {scale}") - image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0] + image = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + ).images[0] ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"