mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support LoRA merge in advance
This commit is contained in:
@@ -2271,6 +2271,7 @@ def main(args):
|
|||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not args.network_merge:
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
@@ -2278,6 +2279,9 @@ def main(args):
|
|||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
|
else:
|
||||||
|
network.merge_to(text_encoder, unet, dtype, device)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
networks = []
|
networks = []
|
||||||
|
|
||||||
@@ -3074,6 +3078,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
|
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
|
||||||
)
|
)
|
||||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||||
|
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--textual_inversion_embeddings",
|
"--textual_inversion_embeddings",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -66,6 +66,37 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.org_module.forward = self.forward
|
self.org_module.forward = self.forward
|
||||||
del self.org_module
|
del self.org_module
|
||||||
|
|
||||||
|
def merge_to(self, sd, dtype, device):
    """Fold this LoRA's weights directly into the wrapped original module.

    Computes the LoRA delta ``multiplier * (up @ down) * scale`` in float32
    on ``device``, adds it to the original module's ``weight``, and writes the
    result back via ``load_state_dict`` cast to ``dtype``.

    Args:
        sd: state-dict slice for this module, holding "lora_up.weight" and
            "lora_down.weight".
        dtype: dtype to cast the merged weight to before storing it.
        device: device on which the merge arithmetic is performed.
    """
    # Promote both LoRA factors to float32 for the merge arithmetic.
    up = sd["lora_up.weight"].to(torch.float).to(device)
    down = sd["lora_down.weight"].to(torch.float).to(device)

    # Pull the current weight out of the original module.
    state = self.org_module.state_dict()
    merged = state["weight"].to(torch.float)

    if merged.dim() == 2:
        # Linear layer: the delta is a plain matrix product.
        delta = up @ down
    elif down.size()[2:4] == (1, 1):
        # 1x1 conv: squeeze the spatial dims, matmul over channels, restore them.
        delta = (up.squeeze(3).squeeze(2) @ down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
    else:
        # 3x3 conv: compose the down and up convolutions into a single kernel.
        delta = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

    merged = merged + self.multiplier * delta * self.scale

    # Store the merged weight back into the original module at the target dtype.
    state["weight"] = merged.to(dtype)
    self.org_module.load_state_dict(state)
|
||||||
|
|
||||||
def set_region(self, region):
    """Attach a region to this module and reset its cached region mask."""
    # Any previously computed mask is stale once the region changes.
    self.region_mask = None
    self.region = region
|
||||||
@@ -344,6 +375,35 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
info = self.load_state_dict(self.weights_sd, False)
|
info = self.load_state_dict(self.weights_sd, False)
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
|
|
||||||
|
# TODO refactor to common function with apply_to
|
||||||
|
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, dtype, device):
    """Merge the loaded LoRA weights directly into the original models.

    Unlike ``apply_to`` (which hooks forward), this permanently folds each
    LoRA module's delta into the corresponding original weight via
    ``LoRAModule.merge_to``. Requires ``load_weights`` to have populated
    ``self.weights_sd`` first.

    Args:
        text_encoder: unused here; kept for signature parity with apply_to.
        unet: unused here; kept for signature parity with apply_to.
        dtype: dtype the merged weights are cast to.
        device: device on which the merge arithmetic runs.
    """
    assert self.weights_sd is not None, "weights are not loaded"

    # Detect which sub-models the checkpoint actually targets.
    apply_text_encoder = apply_unet = False
    for key in self.weights_sd.keys():
        if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
            apply_text_encoder = True
        elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
            apply_unet = True

    if apply_text_encoder:
        print("enable LoRA for text encoder")
    else:
        self.text_encoder_loras = []

    if apply_unet:
        print("enable LoRA for U-Net")
    else:
        self.unet_loras = []

    for lora in self.text_encoder_loras + self.unet_loras:
        # Keys are "{lora_name}.{param}". Matching on the dotted prefix (not a
        # bare startswith of the name) prevents a module whose name is a strict
        # prefix of another's from stealing — and mis-slicing — its entries.
        prefix = lora.lora_name + "."
        sd_for_lora = {
            key[len(prefix):]: value
            for key, value in self.weights_sd.items()
            if key.startswith(prefix)
        }
        lora.merge_to(sd_for_lora, dtype, device)

    print("weights are merged")
|
||||||
|
|
||||||
def enable_gradient_checkpointing(self):
    """No-op: gradient checkpointing is not supported for this network."""
    pass
|
||||||
|
|||||||
Reference in New Issue
Block a user