mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add convenience function to merge LoRA
This commit is contained in:
@@ -257,6 +257,15 @@ def create_network_from_weights(
|
|||||||
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||||
|
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
||||||
|
unet = pipe.unet
|
||||||
|
|
||||||
|
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
||||||
|
lora_network.load_state_dict(weights_sd)
|
||||||
|
lora_network.merge_to(multiplier=multiplier)
|
||||||
|
|
||||||
|
|
||||||
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
||||||
class LoRANetwork(torch.nn.Module):
|
class LoRANetwork(torch.nn.Module):
|
||||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||||
@@ -432,7 +441,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||||
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
||||||
map_keys.sort()
|
map_keys.sort()
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
||||||
@@ -584,3 +593,12 @@ if __name__ == "__main__":
|
|||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "restore_original.png")
|
image.save(image_prefix + "restore_original.png")
|
||||||
|
|
||||||
|
# use convenience function to merge LoRA weights
|
||||||
|
print(f"merge LoRA weights with convenience function")
|
||||||
|
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||||
|
|
||||||
|
print(f"create image with merged LoRA weights")
|
||||||
|
seed_everything(args.seed)
|
||||||
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
|
image.save(image_prefix + "convenience_merged_lora.png")
|
||||||
|
|||||||
Reference in New Issue
Block a user