mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add convenience function to merge LoRA
This commit is contained in:
@@ -257,6 +257,15 @@ def create_network_from_weights(
|
||||
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||
|
||||
|
||||
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
||||
unet = pipe.unet
|
||||
|
||||
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
||||
lora_network.load_state_dict(weights_sd)
|
||||
lora_network.merge_to(multiplier=multiplier)
|
||||
|
||||
|
||||
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
@@ -432,7 +441,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
||||
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
||||
map_keys.sort()
|
||||
for key in list(state_dict.keys()):
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
||||
@@ -584,3 +593,12 @@ if __name__ == "__main__":
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "restore_original.png")
|
||||
|
||||
# use convenience function to merge LoRA weights
|
||||
print(f"merge LoRA weights with convenience function")
|
||||
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||
|
||||
print(f"create image with merged LoRA weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "convenience_merged_lora.png")
|
||||
|
||||
Reference in New Issue
Block a user