Fix TE key names for SD1/2 LoRA are invalid

This commit is contained in:
Kohya S
2023-07-08 09:56:38 +09:00
parent c1d62383c6
commit 66c03be45f

View File

@@ -735,7 +735,7 @@ class LoRANetwork(torch.nn.Module):
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te" LORA_PREFIX_TEXT_ENCODER = "lora_te"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
@@ -877,7 +877,14 @@ class LoRANetwork(torch.nn.Module):
self.text_encoder_loras = [] self.text_encoder_loras = []
skipped_te = [] skipped_te = []
for i, text_encoder in enumerate(text_encoders): for i, text_encoder in enumerate(text_encoders):
text_encoder_loras, skipped = create_modules(False, i + 1, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras) self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")