add convenience function to merge LoRA

This commit is contained in:
Kohya S
2023-07-17 10:30:57 +09:00
parent 1d4672d747
commit 7e20c6d1a1

View File

@@ -257,6 +257,15 @@ def create_network_from_weights(
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
unet = pipe.unet
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
lora_network.load_state_dict(weights_sd)
lora_network.merge_to(multiplier=multiplier)
# block weightや学習に対応しない簡易版 / simple version without block weight and training # block weightや学習に対応しない簡易版 / simple version without block weight and training
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
@@ -432,7 +441,7 @@ class LoRANetwork(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
# convert SDXL Stability AI's state dict to Diffusers' based state dict # convert SDXL Stability AI's state dict to Diffusers' based state dict
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
map_keys.sort() map_keys.sort()
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"): if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
@@ -584,3 +593,12 @@ if __name__ == "__main__":
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "restore_original.png") image.save(image_prefix + "restore_original.png")
# use convenience function to merge LoRA weights
print(f"merge LoRA weights with convenience function")
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
print(f"create image with merged LoRA weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "convenience_merged_lora.png")