support LoRA merge in advance

This commit is contained in:
Kohya S
2023-03-30 21:34:36 +09:00
parent cb53a77334
commit 2d6faa9860
2 changed files with 81 additions and 16 deletions

View File

@@ -2271,6 +2271,7 @@ def main(args):
if network is None: if network is None:
return return
if not args.network_merge:
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
if args.opt_channels_last: if args.opt_channels_last:
@@ -2278,6 +2279,9 @@ def main(args):
network.to(dtype).to(device) network.to(dtype).to(device)
networks.append(network) networks.append(network)
else:
network.merge_to(text_encoder, unet, dtype, device)
else: else:
networks = [] networks = []
@@ -3074,6 +3078,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
) )
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument( parser.add_argument(
"--textual_inversion_embeddings", "--textual_inversion_embeddings",
type=str, type=str,

View File

@@ -66,6 +66,37 @@ class LoRAModule(torch.nn.Module):
self.org_module.forward = self.forward self.org_module.forward = self.forward
del self.org_module del self.org_module
def merge_to(self, sd, dtype, device):
# get up/down weight
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"].to(torch.float)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
def set_region(self, region): def set_region(self, region):
self.region = region self.region = region
self.region_mask = None self.region_mask = None
@@ -344,6 +375,35 @@ class LoRANetwork(torch.nn.Module):
info = self.load_state_dict(self.weights_sd, False) info = self.load_state_dict(self.weights_sd, False)
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, dtype, device):
assert self.weights_sd is not None, "weights are not loaded"
apply_text_encoder = apply_unet = False
for key in self.weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in self.weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
def enable_gradient_checkpointing(self): def enable_gradient_checkpointing(self):
# not supported # not supported
pass pass