mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add region control for LoRA
This commit is contained in:
@@ -1649,10 +1649,11 @@ def get_unweighted_text_embeddings(
|
|||||||
if pad == eos: # v1
|
if pad == eos: # v1
|
||||||
text_input_chunk[:, -1] = text_input[0, -1]
|
text_input_chunk[:, -1] = text_input[0, -1]
|
||||||
else: # v2
|
else: # v2
|
||||||
if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある
|
for j in range(len(text_input_chunk)):
|
||||||
text_input_chunk[:, -1] = eos
|
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
||||||
if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD
|
text_input_chunk[j, -1] = eos
|
||||||
text_input_chunk[:, 1] = eos
|
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
||||||
|
text_input_chunk[j, 1] = eos
|
||||||
|
|
||||||
if clip_skip is None or clip_skip == 1:
|
if clip_skip is None or clip_skip == 1:
|
||||||
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
||||||
@@ -2276,13 +2277,26 @@ def main(args):
|
|||||||
mask_images = l
|
mask_images = l
|
||||||
|
|
||||||
# 画像サイズにオプション指定があるときはリサイズする
|
# 画像サイズにオプション指定があるときはリサイズする
|
||||||
if init_images is not None and args.W is not None and args.H is not None:
|
if args.W is not None and args.H is not None:
|
||||||
|
if init_images is not None:
|
||||||
print(f"resize img2img source images to {args.W}*{args.H}")
|
print(f"resize img2img source images to {args.W}*{args.H}")
|
||||||
init_images = resize_images(init_images, (args.W, args.H))
|
init_images = resize_images(init_images, (args.W, args.H))
|
||||||
if mask_images is not None:
|
if mask_images is not None:
|
||||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||||
|
|
||||||
|
if networks and mask_images:
|
||||||
|
# mask を領域情報として流用する、現在は1枚だけ対応
|
||||||
|
# TODO 複数のnetwork classの混在時の考慮
|
||||||
|
print("use mask as region")
|
||||||
|
# import cv2
|
||||||
|
# for i in range(3):
|
||||||
|
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
|
||||||
|
# cv2.waitKey()
|
||||||
|
# cv2.destroyAllWindows()
|
||||||
|
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
|
||||||
|
mask_images = None
|
||||||
|
|
||||||
prev_image = None # for VGG16 guided
|
prev_image = None # for VGG16 guided
|
||||||
if args.guide_image_path is not None:
|
if args.guide_image_path is not None:
|
||||||
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
|
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from library import train_util
|
from library import train_util
|
||||||
@@ -45,15 +46,51 @@ class LoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.org_module = org_module # remove in applying
|
self.org_module = org_module # remove in applying
|
||||||
|
self.region = None
|
||||||
|
self.region_mask = None
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module.forward
|
self.org_forward = self.org_module.forward
|
||||||
self.org_module.forward = self.forward
|
self.org_module.forward = self.forward
|
||||||
del self.org_module
|
del self.org_module
|
||||||
|
|
||||||
|
def set_region(self, region):
|
||||||
|
self.region = region
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if self.region is None:
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
# reginal LoRA
|
||||||
|
if x.size()[1] % 77 == 0:
|
||||||
|
# print(f"LoRA for context: {self.lora_name}")
|
||||||
|
self.region = None
|
||||||
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
if self.region_mask is None:
|
||||||
|
if len(x.size()) == 4:
|
||||||
|
h, w = x.size()[2:4]
|
||||||
|
else:
|
||||||
|
seq_len = x.size()[1]
|
||||||
|
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
||||||
|
h = int(self.region.size()[0] / ratio + .5)
|
||||||
|
w = seq_len // h
|
||||||
|
|
||||||
|
r = self.region.to(x.device)
|
||||||
|
if r.dtype == torch.bfloat16:
|
||||||
|
r = r.to(torch.float)
|
||||||
|
r = r.unsqueeze(0).unsqueeze(1)
|
||||||
|
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
||||||
|
r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear')
|
||||||
|
r = r.to(x.dtype)
|
||||||
|
|
||||||
|
if len(x.size()) == 3:
|
||||||
|
r = torch.reshape(r, (1, x.size()[1], -1))
|
||||||
|
|
||||||
|
self.region_mask = r
|
||||||
|
|
||||||
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
||||||
|
|
||||||
|
|
||||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||||
if network_dim is None:
|
if network_dim is None:
|
||||||
@@ -240,3 +277,18 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
save_file(state_dict, file, metadata)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_regions(networks, image):
|
||||||
|
image = image.astype(np.float32) / 255.0
|
||||||
|
for i, network in enumerate(networks[:3]):
|
||||||
|
# NOTE: consider averaging overwrapping area
|
||||||
|
region = image[:, :, i]
|
||||||
|
if region.max() == 0:
|
||||||
|
continue
|
||||||
|
region = torch.tensor(region)
|
||||||
|
network.set_region(region)
|
||||||
|
|
||||||
|
def set_region(self, region):
|
||||||
|
for lora in self.unet_loras:
|
||||||
|
lora.set_region(region)
|
||||||
|
|||||||
Reference in New Issue
Block a user