fix LoRA being always expanded to Conv2d-3x3

Kohya S
2023-03-09 08:47:13 +09:00
parent 3ce846525b
commit c4b4d1cb40

@@ -70,7 +70,7 @@ class LoRAModule(torch.nn.Module):
         if self.region is None:
             return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
-        # reginal LoRA
+        # regional LoRA FIXME same as additional-network extension
         if x.size()[1] % 77 == 0:
             # print(f"LoRA for context: {self.lora_name}")
             self.region = None
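
For context, the unchanged branch above (`self.region is None`) is the standard LoRA forward pass: the frozen module's output plus a low-rank update scaled by `multiplier * scale`. The `x.size()[1] % 77 == 0` test below it appears to detect the text-encoder context, whose token length is a multiple of 77. Here is a minimal self-contained sketch of that forward computation; the `MiniLoRA` name, its constructor signature, and the `Linear`-only assumption are illustrative, not taken from this repository:

```python
import torch

class MiniLoRA(torch.nn.Module):
    """Sketch of the non-regional forward shown in the hunk above."""

    def __init__(self, org_module: torch.nn.Linear, rank: int = 4,
                 alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.org_forward = org_module.forward  # handle to the frozen original layer
        self.lora_down = torch.nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, org_module.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)  # zero init: LoRA starts as a no-op
        self.multiplier = multiplier
        self.scale = alpha / rank  # same alpha/dim scaling as LoRAModule

    def forward(self, x):
        # same form as the return statement in the diff context above
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
```

Because `lora_up` is zero-initialized, the module initially returns exactly `org_forward(x)`, so training starts from the base model's behavior.
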
@@ -107,10 +107,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         network_dim = 4 # default
 
     # extract dim/alpha for conv2d, and block dim
-    conv_dim = int(kwargs.get('conv_dim', network_dim))
-    conv_alpha = kwargs.get('conv_alpha', network_alpha)
-    if conv_alpha is not None:
-        conv_alpha = float(conv_alpha)
+    conv_dim = kwargs.get('conv_dim', None)
+    conv_alpha = kwargs.get('conv_alpha', None)
+    if conv_dim is not None:
+        conv_dim = int(conv_dim)
+        if conv_alpha is None:
+            conv_alpha = float(conv_dim)
+        else:
+            conv_alpha = float(conv_alpha)
 
     """
     block_dims = kwargs.get("block_dims")
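
This second hunk is the fix named in the commit title. Before, `conv_dim` fell back to `network_dim` (and `conv_alpha` to `network_alpha`), so LoRA modules were created for Conv2d-3x3 layers even when the user never requested them; after, both stay `None` unless passed explicitly, making Conv2d-3x3 expansion opt-in, with a missing `conv_alpha` defaulting to `conv_dim` itself (i.e. a scale of `alpha / dim = 1.0`). A sketch of the new parsing in isolation; `parse_conv_args` is my name for it, not a function in the repository:

```python
def parse_conv_args(**kwargs):
    """Mirror of the fixed logic above: conv LoRA only when explicitly requested."""
    conv_dim = kwargs.get('conv_dim', None)
    conv_alpha = kwargs.get('conv_alpha', None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        # alpha defaults to the dim itself, giving scale = alpha / dim = 1.0
        conv_alpha = float(conv_dim) if conv_alpha is None else float(conv_alpha)
    return conv_dim, conv_alpha

# No conv kwargs -> no Conv2d-3x3 expansion (the bug this commit fixes):
assert parse_conv_args() == (None, None)
# Explicit opt-in; values arrive as strings from the command line:
assert parse_conv_args(conv_dim="4", conv_alpha="1") == (4, 1.0)
assert parse_conv_args(conv_dim="8") == (8, 8.0)
```

In sd-scripts these kwargs are typically supplied to `create_network` through the training scripts' `--network_args` option, e.g. `--network_args "conv_dim=4" "conv_alpha=1"`.
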