Add region control for LoRA

This commit is contained in:
Kohya S
2023-03-04 18:03:11 +09:00
parent 45945f698a
commit fe4f4446f1
2 changed files with 75 additions and 9 deletions

View File

@@ -1649,10 +1649,11 @@ def get_unweighted_text_embeddings(
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1]
else: # v2
if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある
text_input_chunk[:, -1] = eos
if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD
text_input_chunk[:, 1] = eos
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos
if clip_skip is None or clip_skip == 1:
text_embedding = pipe.text_encoder(text_input_chunk)[0]
@@ -2276,13 +2277,26 @@ def main(args):
mask_images = l
# Resize the images when an image size is specified via the options
if init_images is not None and args.W is not None and args.H is not None:
print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H))
if args.W is not None and args.H is not None:
if init_images is not None:
print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H))
if mask_images is not None:
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))
if networks and mask_images:
# Reuse the mask as region information; currently only a single mask image is supported
# TODO: account for the case where multiple network classes are mixed
print("use mask as region")
# import cv2
# for i in range(3):
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
# cv2.waitKey()
# cv2.destroyAllWindows()
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
mask_images = None
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")