mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support original controlnet
This commit is contained in:
@@ -80,6 +80,8 @@ from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import library.model_util as model_util
|
||||
import tools.original_control_net as original_control_net
|
||||
from tools.original_control_net import ControlNetInfo
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
@@ -486,6 +488,9 @@ class PipelineLike():
|
||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||
|
||||
# ControlNet
|
||||
self.control_nets: List[ControlNetInfo] = []
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||
self.token_replacements[target_token_id] = rep_token_ids
|
||||
@@ -499,7 +504,11 @@ class PipelineLike():
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def set_control_nets(self, ctrl_nets):
|
||||
self.control_nets = ctrl_nets
|
||||
|
||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
@@ -751,7 +760,7 @@ class PipelineLike():
|
||||
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
|
||||
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
|
||||
|
||||
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
|
||||
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
|
||||
if isinstance(clip_guide_images, PIL.Image.Image):
|
||||
clip_guide_images = [clip_guide_images]
|
||||
|
||||
@@ -764,7 +773,7 @@ class PipelineLike():
|
||||
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
|
||||
if len(image_embeddings_clip) == 1:
|
||||
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
|
||||
else:
|
||||
elif self.vgg16_guidance_scale > 0:
|
||||
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?)
|
||||
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
|
||||
clip_guide_images = torch.cat(clip_guide_images, dim=0)
|
||||
@@ -773,6 +782,10 @@ class PipelineLike():
|
||||
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
|
||||
if len(image_embeddings_vgg16) == 1:
|
||||
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
|
||||
else:
|
||||
# ControlNetのhintにguide imageを流用する
|
||||
# 前処理はControlNet側で行う
|
||||
pass
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
||||
@@ -863,12 +876,21 @@ class PipelineLike():
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
||||
|
||||
if self.control_nets:
|
||||
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
if self.control_nets:
|
||||
noise_pred = original_control_net.call_unet_and_control_net(
|
||||
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
|
||||
else:
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
@@ -2066,6 +2088,18 @@ def main(args):
|
||||
else:
|
||||
networks = []
|
||||
|
||||
# ControlNetの処理
|
||||
control_nets: List[ControlNetInfo] = []
|
||||
if args.control_net_models:
|
||||
for i, model in enumerate(args.control_net_models):
|
||||
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
|
||||
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
|
||||
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
|
||||
|
||||
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
|
||||
prep = original_control_net.load_preprocess(prep_type)
|
||||
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
||||
|
||||
if args.opt_channels_last:
|
||||
print(f"set optimizing: channels last")
|
||||
text_encoder.to(memory_format=torch.channels_last)
|
||||
@@ -2079,9 +2113,14 @@ def main(args):
|
||||
if vgg16_model is not None:
|
||||
vgg16_model.to(memory_format=torch.channels_last)
|
||||
|
||||
for cn in control_nets:
|
||||
cn.unet.to(memory_format=torch.channels_last)
|
||||
cn.net.to(memory_format=torch.channels_last)
|
||||
|
||||
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
|
||||
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
|
||||
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
|
||||
pipe.set_control_nets(control_nets)
|
||||
print("pipeline is ready.")
|
||||
|
||||
if args.diffusers_xformers:
|
||||
@@ -2215,9 +2254,12 @@ def main(args):
|
||||
|
||||
prev_image = None # for VGG16 guided
|
||||
if args.guide_image_path is not None:
|
||||
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
|
||||
guide_images = load_images(args.guide_image_path)
|
||||
print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
|
||||
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
|
||||
guide_images = []
|
||||
for p in args.guide_image_path:
|
||||
guide_images.extend(load_images(p))
|
||||
|
||||
print(f"loaded {len(guide_images)} guide images for guidance")
|
||||
if len(guide_images) == 0:
|
||||
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
||||
guide_images = None
|
||||
@@ -2257,22 +2299,22 @@ def main(args):
|
||||
print("process 1st stage1")
|
||||
batch_1st = []
|
||||
for base, ext in batch:
|
||||
width_1st = int(width * args.highres_fix_scale + .5)
|
||||
height_1st = int(height * args.highres_fix_scale + .5)
|
||||
width_1st = int(ext.width * args.highres_fix_scale + .5)
|
||||
height_1st = int(ext.height * args.highres_fix_scale + .5)
|
||||
width_1st = width_1st - width_1st % 32
|
||||
height_1st = height_1st - height_1st % 32
|
||||
|
||||
bd_1st = BatchData(base, BatchDataExt(width_1st, height_1st, args.highres_fix_steps,
|
||||
ext.scale, ext.negative_scale, ext.strength, ext.network_muls))
|
||||
batch_1st.append(bd_1st)
|
||||
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
||||
ext.negative_scale, ext.strength, ext.network_muls)
|
||||
batch_1st.append(BatchData(base, ext_1st))
|
||||
images_1st = process_batch(batch_1st, True, True)
|
||||
|
||||
# 2nd stageのバッチを作成して以下処理する
|
||||
print("process 2nd stage1")
|
||||
batch_2nd = []
|
||||
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
||||
image = image.resize((width, height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:8]), bd.ext)
|
||||
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
|
||||
batch_2nd.append(bd_2nd)
|
||||
batch = batch_2nd
|
||||
|
||||
@@ -2328,9 +2370,13 @@ def main(args):
|
||||
all_masks_are_same = mask_images[-2] is mask_image
|
||||
|
||||
if guide_image is not None:
|
||||
guide_images.append(guide_image)
|
||||
if i > 0 and all_guide_images_are_same:
|
||||
all_guide_images_are_same = guide_images[-2] is guide_image
|
||||
if type(guide_image) is list:
|
||||
guide_images.extend(guide_image)
|
||||
all_guide_images_are_same = False
|
||||
else:
|
||||
guide_images.append(guide_image)
|
||||
if i > 0 and all_guide_images_are_same:
|
||||
all_guide_images_are_same = guide_images[-2] is guide_image
|
||||
|
||||
# make start code
|
||||
torch.manual_seed(seed)
|
||||
@@ -2353,6 +2399,14 @@ def main(args):
|
||||
if guide_images is not None and all_guide_images_are_same:
|
||||
guide_images = guide_images[0]
|
||||
|
||||
# ControlNet使用時はguide imageをリサイズする
|
||||
if control_nets:
|
||||
# TODO resampleのメソッド
|
||||
guide_images = guide_images if type(guide_images) == list else [guide_images]
|
||||
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
|
||||
if len(guide_images) == 1:
|
||||
guide_images = guide_images[0]
|
||||
|
||||
# generate
|
||||
if networks:
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
@@ -2545,7 +2599,12 @@ def main(args):
|
||||
mask_image = mask_images[global_step % len(mask_images)]
|
||||
|
||||
if guide_images is not None:
|
||||
guide_image = guide_images[global_step % len(guide_images)]
|
||||
if control_nets: # 複数件の場合あり
|
||||
c = len(control_nets)
|
||||
p = global_step % (len(guide_images) // c)
|
||||
guide_image = guide_images[p * c:p * c + c]
|
||||
else:
|
||||
guide_image = guide_images[global_step % len(guide_images)]
|
||||
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
||||
if prev_image is None:
|
||||
print("Generate 1st image without guide image.")
|
||||
@@ -2646,7 +2705,8 @@ if __name__ == '__main__':
|
||||
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
|
||||
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
|
||||
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
|
||||
parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
||||
parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
|
||||
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
|
||||
parser.add_argument("--highres_fix_scale", type=float, default=None,
|
||||
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
|
||||
parser.add_argument("--highres_fix_steps", type=int, default=28,
|
||||
@@ -2656,5 +2716,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--negative_scale", type=float, default=None,
|
||||
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
||||
|
||||
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
|
||||
help='ControlNet models to use / 使用するControlNetのモデル名')
|
||||
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
|
||||
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
|
||||
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
|
||||
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
|
||||
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user