diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 6bab0bb8..8a185170 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1649,10 +1649,11 @@ def get_unweighted_text_embeddings( if pad == eos: # v1 text_input_chunk[:, -1] = text_input[0, -1] else: # v2 - if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある - text_input_chunk[:, -1] = eos - if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD - text_input_chunk[:, 1] = eos + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos if clip_skip is None or clip_skip == 1: text_embedding = pipe.text_encoder(text_input_chunk)[0] @@ -2276,13 +2277,26 @@ def main(args): mask_images = l # 画像サイズにオプション指定があるときはリサイズする - if init_images is not None and args.W is not None and args.H is not None: - print(f"resize img2img source images to {args.W}*{args.H}") - init_images = resize_images(init_images, (args.W, args.H)) + if args.W is not None and args.H is not None: + if init_images is not None: + print(f"resize img2img source images to {args.W}*{args.H}") + init_images = resize_images(init_images, (args.W, args.H)) if mask_images is not None: print(f"resize img2img mask images to {args.W}*{args.H}") mask_images = resize_images(mask_images, (args.W, args.H)) + if networks and mask_images: + # mask を領域情報として流用する、現在は1枚だけ対応 + # TODO 複数のnetwork classの混在時の考慮 + print("use mask as region") + # import cv2 + # for i in range(3): + # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) + # cv2.waitKey() + # cv2.destroyAllWindows() + networks[0].__class__.set_regions(networks, np.array(mask_images[0])) + mask_images = None + prev_image = None # for VGG16 guided if args.guide_image_path is not None: print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") diff --git a/networks/lora.py 
b/networks/lora.py index 24b107ba..2318605a 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -6,6 +6,7 @@ import math import os from typing import List +import numpy as np import torch from library import train_util @@ -45,14 +46,50 @@ class LoRAModule(torch.nn.Module): self.multiplier = multiplier self.org_module = org_module # remove in applying + self.region = None + self.region_mask = None def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward del self.org_module + def set_region(self, region): + self.region = region + def forward(self, x): - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.region is None: + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + # regional LoRA + if x.size()[1] % 77 == 0: + # print(f"LoRA for context: {self.lora_name}") + self.region = None + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + if self.region_mask is None: + if len(x.size()) == 4: + h, w = x.size()[2:4] + else: + seq_len = x.size()[1] + ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len) + h = int(self.region.size()[0] / ratio + .5) + w = seq_len // h + + r = self.region.to(x.device) + if r.dtype == torch.bfloat16: + r = r.to(torch.float) + r = r.unsqueeze(0).unsqueeze(1) + # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w) + r = torch.nn.functional.interpolate(r, (h, w), mode='bilinear') + r = r.to(x.dtype) + + if len(x.size()) == 3: + r = torch.reshape(r, (1, x.size()[1], -1)) + + self.region_mask = r + + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): @@ -130,7 +167,7 @@ class LoRANetwork(torch.nn.Module): self.multiplier = multiplier for lora in self.text_encoder_loras + 
self.unet_loras: lora.multiplier = self.multiplier - + def load_weights(self, file): if os.path.splitext(file)[1] == '.safetensors': from safetensors.torch import load_file, safe_open @@ -240,3 +277,18 @@ class LoRANetwork(torch.nn.Module): save_file(state_dict, file, metadata) else: torch.save(state_dict, file) + + @staticmethod + def set_regions(networks, image): + image = image.astype(np.float32) / 255.0 + for i, network in enumerate(networks[:3]): + # NOTE: consider averaging overlapping area + region = image[:, :, i] + if region.max() == 0: + continue + region = torch.tensor(region) + network.set_region(region) + + def set_region(self, region): + for lora in self.unet_loras: + lora.set_region(region)