fix dylora create_modules error

This commit is contained in:
tamlog06
2024-02-18 13:20:47 +00:00
parent cd19df49cd
commit a6f1ed2e14

View File

@@ -12,7 +12,9 @@
import math import math
import os import os
import random import random
from typing import List, Tuple, Union from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch import torch
from torch import nn from torch import nn
@@ -165,7 +167,15 @@ class DyLoRAModule(torch.nn.Module):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
**kwargs,
):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
if network_alpha is None: if network_alpha is None:
@@ -182,6 +192,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_alpha = 1.0 conv_alpha = 1.0
else: else:
conv_alpha = float(conv_alpha) conv_alpha = float(conv_alpha)
if unit is not None: if unit is not None:
unit = int(unit) unit = int(unit)
else: else:
@@ -307,7 +318,21 @@ class DyLoRANetwork(torch.nn.Module):
loras.append(lora) loras.append(lora)
return loras return loras
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}")
else:
index = None
print(f"create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights