From 2d6faa9860c76b600bb64eb79c89a47fa38cf841 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 30 Mar 2023 21:34:36 +0900 Subject: [PATCH] support LoRA merge in advance --- gen_img_diffusers.py | 17 ++++++---- networks/lora.py | 80 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 81 insertions(+), 16 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index cd0be71b..4dbe5f90 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2262,7 +2262,7 @@ def main(args): metadata = f.metadata() if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") - + network = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, **net_kwargs ) @@ -2271,13 +2271,17 @@ def main(args): if network is None: return - network.apply_to(text_encoder, unet) + if not args.network_merge: + network.apply_to(text_encoder, unet) - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) + else: + network.merge_to(text_encoder, unet, dtype, device) - networks.append(network) else: networks = [] @@ -3074,6 +3078,7 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--textual_inversion_embeddings", type=str, diff --git a/networks/lora.py b/networks/lora.py index d5c07aec..2bf78511 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -66,6 +66,37 @@ class LoRAModule(torch.nn.Module): self.org_module.forward = self.forward del self.org_module + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + def set_region(self, region): self.region = region self.region_mask = None @@ -121,30 +152,30 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha = float(conv_alpha) """ - block_dims = kwargs.get("block_dims") - block_alphas = None + block_dims = kwargs.get("block_dims") + block_alphas = None - if block_dims is not None: + if block_dims is not None: block_dims = [int(d) for d in block_dims.split(',')] assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" block_alphas = kwargs.get("block_alphas") if block_alphas is None: - block_alphas = [1] * len(block_dims) + block_alphas = [1] * len(block_dims) else: - block_alphas = [int(a) for a in block_alphas(',')] + block_alphas = [int(a) for a in block_alphas(',')] assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - conv_block_dims = kwargs.get("conv_block_dims") - conv_block_alphas = None + conv_block_dims = kwargs.get("conv_block_dims") + conv_block_alphas = None - if conv_block_dims is not None: + if conv_block_dims is not None: conv_block_dims = [int(d) for d in conv_block_dims.split(',')] assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" conv_block_alphas = kwargs.get("conv_block_alphas") if conv_block_alphas is None: - conv_block_alphas = [1] * len(conv_block_dims) + conv_block_alphas = [1] * len(conv_block_dims) else: - conv_block_alphas = [int(a) for a in conv_block_alphas(',')] + conv_block_alphas = [int(a) for a in conv_block_alphas(',')] assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" """ @@ -344,6 +375,35 @@ class LoRANetwork(torch.nn.Module): info = self.load_state_dict(self.weights_sd, False) print(f"weights are loaded: {info}") + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, dtype, device): + assert self.weights_sd is not None, "weights are not loaded" + + apply_text_encoder = apply_unet = False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in self.weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + print(f"weights are merged") + def enable_gradient_checkpointing(self): # not supported pass