support for controlnet in sample output

This commit is contained in:
ddPn08
2023-06-01 09:47:37 +09:00
parent 62d00b4520
commit 3bd00b88c2
4 changed files with 159 additions and 28 deletions

View File

@@ -6,7 +6,7 @@ import re
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -426,6 +426,59 @@ def preprocess_mask(mask, scale_factor=8):
return mask
def prepare_controlnet_image(
image: PIL.Image.Image,
width: int,
height: int,
batch_size: int,
num_images_per_prompt: int,
device: torch.device,
dtype: torch.dtype,
do_classifier_free_guidance: bool = False,
guess_mode: bool = False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize(
(width, height), resample=PIL_INTERPOLATION["lanczos"]
)
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -707,6 +760,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet=None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: int = 1,
@@ -767,6 +822,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
controlnet (`diffusers.ControlNetModel`, *optional*):
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
inference.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -785,6 +845,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if controlnet is not None and controlnet_image is None:
raise ValueError("controlnet_image must be provided if controlnet is not None.")
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -824,6 +887,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
else:
mask = None
if controlnet_image is not None:
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
@@ -851,8 +918,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
# perform guidance
if do_classifier_free_guidance: