Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Dropout and Max Norm Regularization for LoRA training (#545)
* Instantiate max_norm
* minor
* Move to end of step
* argparse
* metadata
* phrasing
* Sqrt ratio and logging
* fix logging
* Dropout test
* Dropout Args
* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
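The diff excerpt below covers only the dropout half of the change; the max-norm half ("Instantiate max_norm", "Sqrt ratio and logging", "Move to end of step") lands in the training loop and is not shown here. Roughly, after each optimizer step, every LoRA pair whose effective update (up @ down) * scale exceeds a norm threshold is shrunk back to it. A minimal sketch of that idea, assuming Linear-only pairs and reusing the attribute names from the diff below; this is not the commit's actual implementation:

    import torch

    @torch.no_grad()
    def apply_max_norm(loras, max_norm):
        """Run after each optimizer step ("Move to end of step"): clamp the
        norm of each module's effective update (up @ down) * scale."""
        keys_scaled, norms = 0, []
        for lora in loras:
            updown = lora.lora_up.weight @ lora.lora_down.weight
            norm = updown.norm() * lora.scale  # Frobenius norm of the update
            if norm > max_norm:
                # sqrt of the ratio: scaling BOTH factors shrinks the
                # product by exactly max_norm / norm ("Sqrt ratio")
                ratio = (max_norm / norm).sqrt()
                lora.lora_up.weight *= ratio
                lora.lora_down.weight *= ratio
                keys_scaled += 1
            norms.append(norm.item())
        # numbers one would log: keys scaled this step, mean and max norms
        return keys_scaled, sum(norms) / len(norms), max(norms)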
@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
     replaces forward method of the original Linear, instead of replacing the original Linear module.
     """
 
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
         self.lora_name = lora_name
@@ -60,6 +60,7 @@ class LoRAModule(torch.nn.Module):
 
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
+        self.dropout = dropout
 
     def apply_to(self):
         self.org_forward = self.org_module.forward
@@ -67,7 +68,10 @@ class LoRAModule(torch.nn.Module):
         del self.org_module
 
     def forward(self, x):
-        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+        if self.dropout:
+            return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
+        else:
+            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
 class LoRAInfModule(LoRAModule):
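Two things are worth noting about the new forward above: dropout touches only the low-rank activations, so the frozen base output self.org_forward(x) always passes through unchanged ("Dropout changed to affect LoRA only"), and torch.nn.functional.dropout defaults to training=True, so as written the dropout fires whenever this forward runs (inference presumably goes through LoRAInfModule's own forward). A self-contained sketch with hypothetical layer sizes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    org = nn.Linear(320, 320)                  # frozen original module
    lora_down = nn.Linear(320, 4, bias=False)  # rank-4 down projection
    lora_up = nn.Linear(4, 320, bias=False)    # rank-4 up projection
    multiplier, scale, p = 1.0, 0.25, 0.1      # scale = alpha / lora_dim

    def lora_forward(x):
        # Dropout zeroes individual rank-activations ("neuron dropout") and
        # rescales survivors by 1 / (1 - p); the base path is untouched.
        return org(x) + lora_up(F.dropout(lora_down(x), p=p)) * multiplier * scale

    y = lora_forward(torch.randn(2, 320))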
@@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs):
     return down_lr_weight, mid_lr_weight, up_lr_weight
 
 
-def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
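The new dropout parameter is supplied by the trainer's argument parsing (the "argparse" and "Dropout Args" items in the commit message); that side of the change is not part of this excerpt.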
@@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_block_dims=conv_block_dims,
         conv_block_alphas=conv_block_alphas,
         varbose=True,
+        dropout=dropout,
     )
 
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -681,6 +686,7 @@ class LoRANetwork(torch.nn.Module):
         modules_alpha=None,
         module_class=LoRAModule,
         varbose=False,
+        dropout=None,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but they follow the patterns below
@@ -697,6 +703,8 @@ class LoRANetwork(torch.nn.Module):
         self.alpha = alpha
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
+        self.dropout = dropout
+        print(f"Neuron dropout: p={self.dropout}")
 
         if modules_dim is not None:
             print(f"create LoRA network from weights")
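Two small behaviors follow from the hunk above: the value is stored and logged unconditionally, so "Neuron dropout: p=None" is printed even when dropout is disabled, and since LoRAModule.forward gates on if self.dropout:, an explicit dropout=0.0 is falsy and takes the no-dropout path, which is the right outcome for p=0.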
@@ -755,7 +763,7 @@ class LoRANetwork(torch.nn.Module):
                                 skipped.append(lora_name)
                                 continue
 
-                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
+                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
                             loras.append(lora)
             return loras, skipped
 
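In the spirit of the "Dropout test" item in the commit message (the repository's actual test is not reproduced here), a quick hypothetical sanity check: with p=1.0 the whole low-rank branch is zeroed, so the patched forward must collapse to the frozen base output:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    org = nn.Linear(8, 8)               # stands in for the frozen base layer
    down = nn.Linear(8, 4, bias=False)  # lora_down
    up = nn.Linear(4, 8, bias=False)    # lora_up (bias-free, so up(0) == 0)

    x = torch.randn(3, 8)
    # F.dropout trains by default; p=1.0 zeroes every low-rank activation,
    # leaving only the base path (multiplier and scale taken as 1 here).
    assert torch.allclose(org(x) + up(F.dropout(down(x), p=1.0)), org(x))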