From 7e20c6d1a132d09ed087ef9d054645d9c9b68aca Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 17 Jul 2023 10:30:57 +0900 Subject: [PATCH] add convenience function to merge LoRA --- networks/lora_diffusers.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index f06212df..c41111be 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -257,6 +257,15 @@ def create_network_from_weights( return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) +def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0): + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder] + unet = pipe.unet + + lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier) + lora_network.load_state_dict(weights_sd) + lora_network.merge_to(multiplier=multiplier) + + # block weightや学習に対応しない簡易版 / simple version without block weight and training class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] @@ -432,7 +441,7 @@ class LoRANetwork(torch.nn.Module): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # convert SDXL Stability AI's state dict to Diffusers' based state dict - map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules + map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules map_keys.sort() for key in list(state_dict.keys()): if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"): @@ -584,3 +593,12 @@ if __name__ == "__main__": seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "restore_original.png") + + # use convenience function to merge LoRA weights + print(f"merge LoRA weights with convenience function") + merge_lora_weights(pipe, lora_sd, multiplier=1.0) + + print(f"create image with merged LoRA weights") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "convenience_merged_lora.png")