From 66c03be45f30441cc01aaf7496c1339007de4cf1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 8 Jul 2023 09:56:38 +0900 Subject: [PATCH] Fix TE key names for SD1/2 LoRA are invalid --- networks/lora.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index b6788b99..cd73cbe7 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -735,7 +735,7 @@ class LoRANetwork(torch.nn.Module): TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" - + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" @@ -877,7 +877,14 @@ class LoRANetwork(torch.nn.Module): self.text_encoder_loras = [] skipped_te = [] for i, text_encoder in enumerate(text_encoders): - text_encoder_loras, skipped = create_modules(False, i + 1, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")