support original controlnet

This commit is contained in:
Kohya S
2023-02-20 12:54:44 +09:00
parent d94c0d70fe
commit 014fd3d037
2 changed files with 406 additions and 18 deletions

View File

@@ -80,6 +80,8 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -486,6 +488,9 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# ControlNet
self.control_nets: List[ControlNetInfo] = []
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
@@ -499,7 +504,11 @@ class PipelineLike():
new_tokens.append(token)
return new_tokens
def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets
# region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -751,7 +760,7 @@ class PipelineLike():
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
@@ -764,7 +773,7 @@ class PipelineLike():
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
if len(image_embeddings_clip) == 1:
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
else:
elif self.vgg16_guidance_scale > 0:
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に小さいか?
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -773,6 +782,10 @@ class PipelineLike():
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
if len(image_embeddings_vgg16) == 1:
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
else:
# ControlNetのhintにguide imageを流用する
# 前処理はControlNet側で行う
pass
# set timesteps
self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -863,12 +876,21 @@ class PipelineLike():
extra_step_kwargs["eta"] = eta
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
if self.control_nets:
noise_pred = original_control_net.call_unet_and_control_net(
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
@@ -2066,6 +2088,18 @@ def main(args):
else:
networks = []
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
if args.opt_channels_last:
print(f"set optimizing: channels last")
text_encoder.to(memory_format=torch.channels_last)
@@ -2079,9 +2113,14 @@ def main(args):
if vgg16_model is not None:
vgg16_model.to(memory_format=torch.channels_last)
for cn in control_nets:
cn.unet.to(memory_format=torch.channels_last)
cn.net.to(memory_format=torch.channels_last)
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
pipe.set_control_nets(control_nets)
print("pipeline is ready.")
if args.diffusers_xformers:
@@ -2215,9 +2254,12 @@ def main(args):
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
guide_images = load_images(args.guide_image_path)
print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
guide_images = None
@@ -2257,22 +2299,22 @@ def main(args):
print("process 1st stage1")
batch_1st = []
for base, ext in batch:
width_1st = int(width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5)
width_1st = int(ext.width * args.highres_fix_scale + .5)
height_1st = int(ext.height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32
bd_1st = BatchData(base, BatchDataExt(width_1st, height_1st, args.highres_fix_steps,
ext.scale, ext.negative_scale, ext.strength, ext.network_muls))
batch_1st.append(bd_1st)
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
ext.negative_scale, ext.strength, ext.network_muls)
batch_1st.append(BatchData(base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage1")
batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)):
image = image.resize((width, height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:8]), bd.ext)
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
@@ -2328,9 +2370,13 @@ def main(args):
all_masks_are_same = mask_images[-2] is mask_image
if guide_image is not None:
guide_images.append(guide_image)
if i > 0 and all_guide_images_are_same:
all_guide_images_are_same = guide_images[-2] is guide_image
if type(guide_image) is list:
guide_images.extend(guide_image)
all_guide_images_are_same = False
else:
guide_images.append(guide_image)
if i > 0 and all_guide_images_are_same:
all_guide_images_are_same = guide_images[-2] is guide_image
# make start code
torch.manual_seed(seed)
@@ -2353,6 +2399,14 @@ def main(args):
if guide_images is not None and all_guide_images_are_same:
guide_images = guide_images[0]
# ControlNet使用時はguide imageをリサイズする
if control_nets:
# TODO resampleのメソッド
guide_images = guide_images if type(guide_images) == list else [guide_images]
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
if len(guide_images) == 1:
guide_images = guide_images[0]
# generate
if networks:
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
@@ -2545,7 +2599,12 @@ def main(args):
mask_image = mask_images[global_step % len(mask_images)]
if guide_images is not None:
guide_image = guide_images[global_step % len(guide_images)]
if control_nets: # 複数件の場合あり
c = len(control_nets)
p = global_step % (len(guide_images) // c)
guide_image = guide_images[p * c:p * c + c]
else:
guide_image = guide_images[global_step % len(guide_images)]
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
if prev_image is None:
print("Generate 1st image without guide image.")
@@ -2646,7 +2705,8 @@ if __name__ == '__main__':
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
parser.add_argument("--highres_fix_scale", type=float, default=None,
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
parser.add_argument("--highres_fix_steps", type=int, default=28,
@@ -2656,5 +2716,13 @@ if __name__ == '__main__':
parser.add_argument("--negative_scale", type=float, default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
help='ControlNet models to use / 使用するControlNetのモデル名')
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
args = parser.parse_args()
main(args)