Dropout and Max Norm Regularization for LoRA training (#545)

* Instantiate max_norm

* minor

* Move to end of step

* argparse

* metadata

* phrasing

* Sqrt ratio and logging

* fix logging

* Dropout test

* Dropout Args

* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
AI-Casanova
2023-06-01 00:58:38 -05:00
committed by GitHub
parent 5931948adb
commit 9c7237157d
4 changed files with 77 additions and 9 deletions

View File

@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
@@ -60,6 +60,7 @@ class LoRAModule(torch.nn.Module):
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
def apply_to(self):
self.org_forward = self.org_module.forward
@@ -67,7 +68,10 @@ class LoRAModule(torch.nn.Module):
del self.org_module
def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
if self.dropout:
return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale
else:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
@@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs):
return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
dropout=dropout,
)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -681,6 +686,7 @@ class LoRANetwork(torch.nn.Module):
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
dropout=None
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -697,6 +703,8 @@ class LoRANetwork(torch.nn.Module):
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
print(f"Neuron dropout: p={self.dropout}")
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -755,7 +763,7 @@ class LoRANetwork(torch.nn.Module):
skipped.append(lora_name)
continue
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
loras.append(lora)
return loras, skipped