Add region control for LoRA

This commit is contained in:
Kohya S
2023-03-04 18:03:11 +09:00
parent 45945f698a
commit fe4f4446f1
2 changed files with 75 additions and 9 deletions

View File

@@ -1649,10 +1649,11 @@ def get_unweighted_text_embeddings(
if pad == eos: # v1 if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1] text_input_chunk[:, -1] = text_input[0, -1]
else: # v2 else: # v2
if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある for j in range(len(text_input_chunk)):
text_input_chunk[:, -1] = eos if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD text_input_chunk[j, -1] = eos
text_input_chunk[:, 1] = eos if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos
if clip_skip is None or clip_skip == 1: if clip_skip is None or clip_skip == 1:
text_embedding = pipe.text_encoder(text_input_chunk)[0] text_embedding = pipe.text_encoder(text_input_chunk)[0]
@@ -2276,13 +2277,26 @@ def main(args):
mask_images = l mask_images = l
# 画像サイズにオプション指定があるときはリサイズする # 画像サイズにオプション指定があるときはリサイズする
if init_images is not None and args.W is not None and args.H is not None: if args.W is not None and args.H is not None:
if init_images is not None:
print(f"resize img2img source images to {args.W}*{args.H}") print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H)) init_images = resize_images(init_images, (args.W, args.H))
if mask_images is not None: if mask_images is not None:
print(f"resize img2img mask images to {args.W}*{args.H}") print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H)) mask_images = resize_images(mask_images, (args.W, args.H))
if networks and mask_images:
# mask を領域情報として流用する、現在は1枚だけ対応
# TODO 複数のnetwork classの混在時の考慮
print("use mask as region")
# import cv2
# for i in range(3):
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
# cv2.waitKey()
# cv2.destroyAllWindows()
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
mask_images = None
prev_image = None # for VGG16 guided prev_image = None # for VGG16 guided
if args.guide_image_path is not None: if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")

View File

@@ -6,6 +6,7 @@
import math import math
import os import os
from typing import List from typing import List
import numpy as np
import torch import torch
from library import train_util from library import train_util
@@ -45,15 +46,51 @@ class LoRAModule(torch.nn.Module):
self.multiplier = multiplier self.multiplier = multiplier
self.org_module = org_module # remove in applying self.org_module = org_module # remove in applying
self.region = None
self.region_mask = None
def apply_to(self): def apply_to(self):
self.org_forward = self.org_module.forward self.org_forward = self.org_module.forward
self.org_module.forward = self.forward self.org_module.forward = self.forward
del self.org_module del self.org_module
def set_region(self, region):
self.region = region
def forward(self, x): def forward(self, x):
if self.region is None:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
# regional LoRA
if x.size()[1] % 77 == 0:
# print(f"LoRA for context: {self.lora_name}")
self.region = None
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
if self.region_mask is None:
if len(x.size()) == 4:
h, w = x.size()[2:4]
else:
seq_len = x.size()[1]
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
h = int(self.region.size()[0] / ratio + .5)
w = seq_len // h
r = self.region.to(x.device)
if r.dtype == torch.bfloat16:
r = r.to(torch.float)
r = r.unsqueeze(0).unsqueeze(1)
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
r = r.to(x.dtype)
if len(x.size()) == 3:
r = torch.reshape(r, (1, x.size()[1], -1))
self.region_mask = r
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None: if network_dim is None:
@@ -240,3 +277,18 @@ class LoRANetwork(torch.nn.Module):
save_file(state_dict, file, metadata) save_file(state_dict, file, metadata)
else: else:
torch.save(state_dict, file) torch.save(state_dict, file)
@staticmethod
def set_regions(networks, image):
image = image.astype(np.float32) / 255.0
for i, network in enumerate(networks[:3]):
# NOTE: consider averaging overlapping area
region = image[:, :, i]
if region.max() == 0:
continue
region = torch.tensor(region)
network.set_region(region)
def set_region(self, region):
for lora in self.unet_loras:
lora.set_region(region)