diff --git a/networks/lora.py b/networks/lora.py
index 9a5d95b9..cdc6b415 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -70,7 +70,7 @@ class LoRAModule(torch.nn.Module):
         if self.region is None:
             return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
-        # reginal LoRA
+        # regional LoRA   FIXME same as additional-network extension
         if x.size()[1] % 77 == 0:
             # print(f"LoRA for context: {self.lora_name}")
             self.region = None
@@ -107,10 +107,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         network_dim = 4  # default
 
     # extract dim/alpha for conv2d, and block dim
-    conv_dim = int(kwargs.get('conv_dim', network_dim))
-    conv_alpha = kwargs.get('conv_alpha', network_alpha)
-    if conv_alpha is not None:
-        conv_alpha = float(conv_alpha)
+    conv_dim = kwargs.get('conv_dim', None)
+    conv_alpha = kwargs.get('conv_alpha', None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        if conv_alpha is None:
+            conv_alpha = float(conv_dim)
+        else:
+            conv_alpha = float(conv_alpha)
 
     """
     block_dims = kwargs.get("block_dims")