support for controlnet in sample output

ddPn08
2023-06-01 09:47:37 +09:00
parent 62d00b4520
commit 3bd00b88c2
4 changed files with 159 additions and 28 deletions

View File

@@ -6,7 +6,7 @@ import re
 from typing import Callable, List, Optional, Union
 
 import numpy as np
-import PIL
+import PIL.Image
 import torch
 from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -426,6 +426,59 @@ def preprocess_mask(mask, scale_factor=8):
     return mask
 
 
+def prepare_controlnet_image(
+    image: PIL.Image.Image,
+    width: int,
+    height: int,
+    batch_size: int,
+    num_images_per_prompt: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    do_classifier_free_guidance: bool = False,
+    guess_mode: bool = False,
+):
+    if not isinstance(image, torch.Tensor):
+        if isinstance(image, PIL.Image.Image):
+            image = [image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            images = []
+
+            for image_ in image:
+                image_ = image_.convert("RGB")
+                image_ = image_.resize(
+                    (width, height), resample=PIL_INTERPOLATION["lanczos"]
+                )
+                image_ = np.array(image_)
+                image_ = image_[None, :]
+                images.append(image_)
+
+            image = images
+
+            image = np.concatenate(image, axis=0)
+            image = np.array(image).astype(np.float32) / 255.0
+            image = image.transpose(0, 3, 1, 2)
+            image = torch.from_numpy(image)
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, dim=0)
+
+    image_batch_size = image.shape[0]
+
+    if image_batch_size == 1:
+        repeat_by = batch_size
+    else:
+        # image batch size is the same as prompt batch size
+        repeat_by = num_images_per_prompt
+
+    image = image.repeat_interleave(repeat_by, dim=0)
+
+    image = image.to(device=device, dtype=dtype)
+
+    if do_classifier_free_guidance and not guess_mode:
+        image = torch.cat([image] * 2)
+
+    return image
+
+
 class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
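This helper mirrors `prepare_image` from the diffusers ControlNet pipeline: a PIL image (or a list of images, or an already-batched tensor) is normalized to an NCHW float tensor in [0, 1], repeated to match the prompt batch, and doubled under classifier-free guidance so it lines up with the concatenated unconditional/conditional embeddings. A minimal sketch of the resulting shapes (the sizes and batch values below are illustrative, not part of the commit):

import PIL.Image
import torch

# hypothetical call: one conditioning image, batch_size=2, CFG enabled
cond = PIL.Image.new("RGB", (512, 512))
image = prepare_controlnet_image(
    cond, width=512, height=512, batch_size=2, num_images_per_prompt=1,
    device=torch.device("cpu"), dtype=torch.float32,
    do_classifier_free_guidance=True, guess_mode=False,
)
assert image.shape == (4, 3, 512, 512)  # batch repeated to 2, then doubled for CFG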
@@ -707,6 +760,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        controlnet=None,
+        controlnet_image=None,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: int = 1,
@@ -767,6 +822,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
+            controlnet (`diffusers.ControlNetModel`, *optional*):
+                A ControlNet model to be used for inference. If not provided, ControlNet is disabled.
+            controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+                `Image`, or tensor representing an image batch, to be used as the conditioning image for the
+                ControlNet inference.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -785,6 +845,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
+        if controlnet is not None and controlnet_image is None:
+            raise ValueError("controlnet_image must be provided if controlnet is not None.")
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -824,6 +887,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         else:
             mask = None
 
+        if controlnet_image is not None:
+            controlnet_image = prepare_controlnet_image(
+                controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
+            )
+
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
@@ -851,8 +918,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+            unet_additional_args = {}
+            if controlnet is not None:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    controlnet_cond=controlnet_image,
+                    conditioning_scale=1.0,
+                    guess_mode=False,
+                    return_dict=False,
+                )
+                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
+                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
+
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
 
             # perform guidance
             if do_classifier_free_guidance:
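With these hooks, ControlNet is opt-in per call: when `controlnet` is passed, its down-block and mid-block residuals are computed at every step and injected into the UNet through `**unet_additional_args`; otherwise the dict stays empty and the pipeline behaves exactly as before. A hedged usage sketch (the model id and file name are illustrative, and `pipe` is assumed to be an instance of this pipeline):

from PIL import Image
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
cond = Image.open("canny.png").convert("RGB")  # resized internally by prepare_controlnet_image

image = pipe(
    prompt="a house in the forest",
    num_inference_steps=20,
    guidance_scale=7.5,
    controlnet=controlnet,       # new argument in this commit
    controlnet_image=cond,       # new argument in this commit
).images[0]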

View File

@@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
     return new_state_dict
 
 
-def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+def controlnet_conversion_map():
     unet_conversion_map = [
         ("time_embed.0.weight", "time_embedding.linear_1.weight"),
         ("time_embed.0.bias", "time_embedding.linear_1.bias"),
@@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
         sd_prefix = f"zero_convs.{i}.0."
         unet_conversion_map_layer.append((sd_prefix, hf_prefix))
 
+    return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
+
+
+def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
     mapping = {k: k for k in controlnet_state_dict.keys()}
     for sd_name, diffusers_name in unet_conversion_map:
         mapping[diffusers_name] = sd_name
@@ -807,6 +812,23 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
     new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
     return new_state_dict
 
+
+def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+    mapping = {k: k for k in controlnet_state_dict.keys()}
+    for sd_name, diffusers_name in unet_conversion_map:
+        mapping[sd_name] = diffusers_name
+
+    for k, v in mapping.items():
+        for sd_part, diffusers_part in unet_conversion_map_layer:
+            v = v.replace(sd_part, diffusers_part)
+        mapping[k] = v
+
+    for k, v in mapping.items():
+        if "resnets" in v:
+            for sd_part, diffusers_part in unet_conversion_map_resnet:
+                v = v.replace(sd_part, diffusers_part)
+            mapping[k] = v
+
+    new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
+
 
 # ================#
 # VAE Conversion #
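`convert_controlnet_state_dict_to_sd` and the new `convert_controlnet_state_dict_to_diffusers` apply the tables from `controlnet_conversion_map()` in opposite directions, so converting to SD format and back should reproduce the original keys and tensors. A quick sanity sketch (assumes a diffusers `ControlNetModel` instance named `controlnet` is in scope):

import torch

state_dict = controlnet.state_dict()
sd_format = convert_controlnet_state_dict_to_sd(state_dict)
roundtrip = convert_controlnet_state_dict_to_diffusers(sd_format)

# the round trip should be lossless if the two maps are exact inverses
assert set(roundtrip.keys()) == set(state_dict.keys())
assert all(torch.equal(roundtrip[k], state_dict[k]) for k in state_dict)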
@@ -928,7 +950,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False): def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.

View File

@@ -1674,7 +1674,6 @@ class ControlNetDataset(BaseDataset):
             cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
             cond_img = self.conditioning_image_transforms(cond_img)
             conditioning_images.append(cond_img)
-        conditioning_images = torch.stack(conditioning_images)
 
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)
@@ -1699,7 +1698,7 @@ class ControlNetDataset(BaseDataset):
         if self.debug_dataset:
             example["image_keys"] = bucket[image_index : image_index + self.batch_size]
-        example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
+        example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
 
         return example
@@ -3138,13 +3137,13 @@ def prepare_dtype(args: argparse.Namespace):
     return weight_dtype, save_dtype
 
 
-def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
     name_or_path = args.pretrained_model_name_or_path
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
     if load_stable_diffusion_format:
         print(f"load StableDiffusion checkpoint: {name_or_path}")
-        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
+        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
     else:
         # Diffusers model is loaded to CPU
         print(f"load Diffusers pretrained models: {name_or_path}")
@@ -3172,14 +3171,14 @@ def transform_if_model_is_DDP(text_encoder, unet, network=None):
     return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
 
 
-def load_target_model(args, weight_dtype, accelerator):
+def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
 
             text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
-                args, weight_dtype, accelerator.device if args.lowram else "cpu"
+                args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
             )
 
             # work on low-ram device
@@ -3493,7 +3492,7 @@ SCHEDLER_SCHEDULE = "scaled_linear"
def sample_images( def sample_images(
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
): ):
""" """
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
@@ -3609,6 +3608,7 @@ def sample_images(
                 height = prompt.get("height", 512)
                 scale = prompt.get("scale", 7.5)
                 seed = prompt.get("seed")
+                controlnet_image = prompt.get("controlnet_image")
                 prompt = prompt.get("prompt")
             else:
                 # prompt = prompt.strip()
@@ -3623,6 +3623,7 @@ def sample_images(
                 width = height = 512
                 scale = 7.5
                 seed = None
+                controlnet_image = None
 
                 for parg in prompt_args:
                     try:
                         m = re.match(r"w (\d+)", parg, re.IGNORECASE)
@@ -3655,6 +3656,12 @@ def sample_images(
                            negative_prompt = m.group(1)
                            continue
 
+                        m = re.match(r"cn (.+)", parg, re.IGNORECASE)
+                        if m:  # controlnet image
+                            controlnet_image = m.group(1)
+                            continue
+
                     except ValueError as ex:
                         print(f"Exception in parsing / 解析エラー: {parg}")
                         print(ex)
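sample_images splits each text-form prompt on " --" and matches single-letter options with these regexes, so the new cn option rides alongside the existing ones, while dictionary-form prompts read the same value from a controlnet_image key. Illustrative entries for a sample-prompts file (paths and option values are placeholders):

# text form, one prompt per line
a house in the forest --w 512 --h 512 --d 1 --n low quality --cn ./cond/canny.png

# JSON form, matching prompt.get("controlnet_image") above
{"prompt": "a house in the forest", "seed": 1, "controlnet_image": "./cond/canny.png"}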
@@ -3668,6 +3675,10 @@ def sample_images(
             if negative_prompt is not None:
                 negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
 
+            if controlnet_image is not None:
+                controlnet_image = Image.open(controlnet_image).convert("RGB")
+                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
             height = max(64, height - height % 8)  # round to divisible by 8
             width = max(64, width - width % 8)  # round to divisible by 8
 
             print(f"prompt: {prompt}")
@@ -3683,6 +3694,8 @@ def sample_images(
                 num_inference_steps=sample_steps,
                 guidance_scale=scale,
                 negative_prompt=negative_prompt,
+                controlnet=controlnet,
+                controlnet_image=controlnet_image,
             ).images[0]
 
             ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())

View File

@@ -1,5 +1,6 @@
 import argparse
 import gc
+import json
 import math
 import os
 import random
@@ -11,6 +12,7 @@ import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler, ControlNetModel
+from safetensors.torch import load_file
 
 import library.model_util as model_util
 import library.train_util as train_util
@@ -26,9 +28,6 @@ from library.custom_train_functions import (
     pyramid_noise_like,
     apply_noise_offset,
 )
-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-    download_controlnet_from_original_ckpt,
-)
 
 # TODO: unify this with the other scripts
@@ -124,19 +123,24 @@ def train(args):
     # load the models
     text_encoder, vae, unet, _ = train_util.load_target_model(
-        args, weight_dtype, accelerator
+        args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
     )
 
+    controlnet = ControlNetModel.from_unet(unet)
+
     if args.controlnet_model_name_or_path:
-        if os.path.isfile(args.controlnet_model_name_or_path):
-            controlnet = download_controlnet_from_original_ckpt(
-                args.controlnet_model_name_or_path
-            )
-        else:
-            controlnet = ControlNetModel.from_pretrained(
-                args.controlnet_model_name_or_path
-            )
-    else:
-        controlnet = ControlNetModel.from_unet(unet)
+        filename = args.controlnet_model_name_or_path
+        if os.path.isfile(filename):
+            if os.path.splitext(filename)[1] == ".safetensors":
+                state_dict = load_file(filename)
+            else:
+                state_dict = torch.load(filename)
+            state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
+            controlnet.load_state_dict(state_dict)
+        elif os.path.isdir(filename):
+            controlnet = ControlNetModel.from_pretrained(filename)
 
     # incorporate xformers or memory efficient attention into the models
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -289,7 +293,9 @@ def train(args):
         )
 
     if accelerator.is_main_process:
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name
+            "controlnet_train"
+            if args.log_tracker_name is None
+            else args.log_tracker_name
         )
 
     loss_list = []
@@ -350,7 +356,7 @@ def train(args):
                 b_size = latents.shape[0]
 
                 input_ids = batch["input_ids"].to(accelerator.device)
-                encoder_hidden_states = train_util.gethidden_states(
+                encoder_hidden_states = train_util.get_hidden_states(
                     args, input_ids, tokenizer, text_encoder, weight_dtype
                 )
@@ -450,6 +456,7 @@ def train(args):
                         tokenizer,
                         text_encoder,
                         unet,
+                        controlnet=controlnet,
                     )
 
                 # save the model every specified number of steps
@@ -537,6 +544,7 @@ def train(args):
                 tokenizer,
                 text_encoder,
                 unet,
+                controlnet=controlnet,
             )
 
     # end of epoch
@@ -569,6 +577,13 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser) custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument( parser.add_argument(
"--controlnet_model_name_or_path", "--controlnet_model_name_or_path",
type=str, type=str,
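With --save_model_as defaulting to safetensors, the trained ControlNet is presumably written back out through the inverse converter added in model_util above; a hedged sketch of that save step (helper names are from this commit, the output path is illustrative):

from safetensors.torch import save_file

import library.model_util as model_util

# convert the diffusers-format weights back to single-file SD layout and save
state_dict = model_util.convert_controlnet_state_dict_to_sd(controlnet.state_dict())
save_file(state_dict, "controlnet.safetensors")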