From c4b4d1cb406ed15935d3e30ffdd6d1c029093b30 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 9 Mar 2023 08:47:13 +0900
Subject: [PATCH] fix LoRA always expanded to Conv2d-3x3

---
 networks/lora.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index 9a5d95b9..cdc6b415 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -70,7 +70,7 @@ class LoRAModule(torch.nn.Module):
         if self.region is None:
             return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
-        # reginal LoRA
+        # regional LoRA   FIXME same as additional-network extension
         if x.size()[1] % 77 == 0:
             # print(f"LoRA for context: {self.lora_name}")
             self.region = None
@@ -107,10 +107,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         network_dim = 4  # default
 
     # extract dim/alpha for conv2d, and block dim
-    conv_dim = int(kwargs.get('conv_dim', network_dim))
-    conv_alpha = kwargs.get('conv_alpha', network_alpha)
-    if conv_alpha is not None:
-        conv_alpha = float(conv_alpha)
+    conv_dim = kwargs.get('conv_dim', None)
+    conv_alpha = kwargs.get('conv_alpha', None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        if conv_alpha is None:
+            conv_alpha = float(conv_dim)
+        else:
+            conv_alpha = float(conv_alpha)
 
     """
     block_dims = kwargs.get("block_dims")
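
Note on the change: before this patch, conv_dim fell back to network_dim, so Conv2d-3x3 LoRA modules were created even when the user never passed conv_dim; afterwards conv_dim stays None unless given explicitly, and conv_alpha defaults to conv_dim (i.e. scale 1.0) only when conv_dim is set. The following is a minimal standalone sketch of the patched defaulting logic; resolve_conv_params is a hypothetical helper name used for illustration, not a function in networks/lora.py.

    # Hypothetical helper mirroring the patched kwargs handling above.
    def resolve_conv_params(kwargs):
        # conv_dim no longer falls back to network_dim, so no Conv2d-3x3
        # expansion happens unless the user asks for it.
        conv_dim = kwargs.get('conv_dim', None)
        conv_alpha = kwargs.get('conv_alpha', None)
        if conv_dim is not None:
            conv_dim = int(conv_dim)
            if conv_alpha is None:
                conv_alpha = float(conv_dim)  # alpha defaults to dim -> scale 1.0
            else:
                conv_alpha = float(conv_alpha)
        return conv_dim, conv_alpha

    # Expected behavior under this sketch:
    assert resolve_conv_params({}) == (None, None)                        # no conv LoRA
    assert resolve_conv_params({'conv_dim': '8'}) == (8, 8.0)             # alpha follows dim
    assert resolve_conv_params({'conv_dim': '8', 'conv_alpha': '4'}) == (8, 4.0)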