mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add convenience function to merge LoRA
This commit is contained in:
@@ -257,6 +257,15 @@ def create_network_from_weights(
|
|||||||
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||||
|
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
||||||
|
unet = pipe.unet
|
||||||
|
|
||||||
|
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
||||||
|
lora_network.load_state_dict(weights_sd)
|
||||||
|
lora_network.merge_to(multiplier=multiplier)
|
||||||
|
|
||||||
|
|
||||||
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
||||||
class LoRANetwork(torch.nn.Module):
|
class LoRANetwork(torch.nn.Module):
|
||||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||||
@@ -584,3 +593,12 @@ if __name__ == "__main__":
|
|||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "restore_original.png")
|
image.save(image_prefix + "restore_original.png")
|
||||||
|
|
||||||
|
# use convenience function to merge LoRA weights
|
||||||
|
print(f"merge LoRA weights with convenience function")
|
||||||
|
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||||
|
|
||||||
|
print(f"create image with merged LoRA weights")
|
||||||
|
seed_everything(args.seed)
|
||||||
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
|
image.save(image_prefix + "convenience_merged_lora.png")
|
||||||
|
|||||||
Reference in New Issue
Block a user