Add LoRA dropout

Drop out whole columns of the LoRA down weight and whole rows of the LoRA up weight.
LoRA Dropout as a Sparsity Regularizer for Overfitting Control
rockerBOO
2025-04-13 02:26:06 -04:00
parent 5a18a03ffc
commit 4230f882d1
2 changed files with 76 additions and 32 deletions
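
For context, the masking that the new helpers implement reads roughly as follows (a sketch in the paper's notation, with A = lora_down and B = lora_up; the usual α/r LoRA scaling is assumed here and is not part of this commit):

\hat{A} = A\,\mathrm{diag}(m_A), \quad \hat{B} = \mathrm{diag}(m_B)\,B, \quad m_A \sim \mathrm{Bern}(1-p)^{d_{\mathrm{in}}}, \; m_B \sim \mathrm{Bern}(1-p)^{d_{\mathrm{out}}}

h = W_0 x + \frac{\alpha}{r}\,\hat{B}\hat{A}x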

library/network_utils.py (new file, +23 lines)

@@ -0,0 +1,23 @@
import torch
from torch import Tensor
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
def lora_dropout_down(down: Tensor, x: Tensor, dropout_prob: float = 0.5) -> Tensor:
    """A' = A · diag(m_A), m_A ~ Bern(1 - p)"""
    # Keep the mask in the weight's dtype/device so the matmul below does not mix dtypes
    mask = torch.bernoulli(
        torch.ones(down.shape[1], device=down.device, dtype=down.dtype) * (1 - dropout_prob)
    )
    # Apply the input-dimension mask (columns of the down-projection)
    lx = x @ (down * mask.view(1, -1)).t()
    return lx
def lora_dropout_up(up: Tensor, x: Tensor, dropout_prob: float = 0.5) -> Tensor:
    """B' = diag(m_B) · B, m_B ~ Bern(1 - p)"""
    # Keep the mask in the weight's dtype/device so the matmul below does not mix dtypes
    mask = torch.bernoulli(
        torch.ones(up.shape[0], device=up.device, dtype=up.dtype) * (1 - dropout_prob)
    )
    # Apply the output-dimension mask (rows of the up-projection)
    lx = x @ (up * mask.view(-1, 1)).t()
    return lx
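
A minimal sketch of how the two helpers slot into a LoRA forward pass during training; the layer sizes, rank, and dropout probability below are illustrative only, not values taken from this commit:

import torch
from torch import nn
from library.network_utils import lora_dropout_down, lora_dropout_up

# Illustrative rank-4 LoRA pair wrapping a 320 -> 640 linear layer
lora_down = nn.Linear(320, 4, bias=False)   # A: weight shape (4, 320)
lora_up = nn.Linear(4, 640, bias=False)     # B: weight shape (640, 4)
x = torch.randn(2, 320)

# Training path: mask input columns of A, then output rows of B
lx = lora_dropout_down(lora_down.weight, x, dropout_prob=0.1)   # (2, 4)
out = lora_dropout_up(lora_up.weight, lx, dropout_prob=0.1)     # (2, 640)

# Eval (or dropout_prob == 0) keeps the plain module calls
out_eval = lora_up(lora_down(x))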


@@ -18,7 +18,7 @@ import torch
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
from library.network_utils import lora_dropout_down, lora_dropout_up
setup_logging()
import logging
@@ -45,6 +45,7 @@ class LoRAModule(torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
lora_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
@@ -106,6 +107,7 @@ class LoRAModule(torch.nn.Module):
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.lora_dropout = lora_dropout
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -132,7 +134,11 @@ class LoRAModule(torch.nn.Module):
return org_forwarded
if self.split_dims is None:
lx = self.lora_down(x)
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
lx = lora_dropout_down(self.lora_down.weight, x, dropout_prob=self.lora_dropout)
else:
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
@@ -153,14 +159,26 @@ class LoRAModule(torch.nn.Module):
else:
scale = self.scale
lx = self.lora_up(lx)
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
lx = lora_dropout_up(self.lora_up.weight, lx, dropout_prob=self.lora_dropout)
else:
lx = self.lora_up(lx)
# LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
if (
self.training
and self.ggpo_sigma is not None
and self.ggpo_beta is not None
and self.combined_weight_norms is not None
and self.grad_norms is not None
):
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
self.ggpo_beta * (self.grad_norms**2)
)
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
@@ -197,24 +215,24 @@ class LoRAModule(torch.nn.Module):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]
# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)
# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()
# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()
# Free memory
del sampled_weights, weights_float32
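
Since the code keeps the sampled standard deviation alongside the mean, the sampling error of the estimate is easy to bound; treating the sampled row norms as i.i.d., a rough 95% interval for the true mean norm is (my gloss, not something the commit computes):

\text{org\_weight\_norm\_estimate} \pm 1.96 \cdot \text{org\_weight\_norm\_std} / \sqrt{\text{sample\_size}}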
@@ -223,37 +241,36 @@ class LoRAModule(torch.nn.Module):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM
for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk
true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()
# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()
# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage
if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")
return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
}
return {
"true_mean_norm": true_mean_norm,
"estimated_norm": estimated_norm,
"absolute_error": absolute_error,
"relative_error": relative_error,
}
@torch.no_grad()
def update_norms(self):
@@ -261,7 +278,7 @@ class LoRAModule(torch.nn.Module):
if self.ggpo_beta is None or self.ggpo_sigma is None:
return
# only update norms when we are training
# only update norms when we are training
if self.training is False:
return
@@ -269,8 +286,9 @@ class LoRAModule(torch.nn.Module):
module_weights.mul(self.scale)
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
self.combined_weight_norms = torch.sqrt(
(self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
)
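
One way to read the formula above: each row's combined norm treats the frozen weight row and the scaled LoRA delta row as roughly orthogonal, with the scalar org_weight_norm_estimate standing in for every frozen row norm (my interpretation, not stated in the code):

\|W_{0,i} + \Delta W_i\| \approx \sqrt{\|W_{0,i}\|^2 + \|\Delta W_i\|^2}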
@torch.no_grad()
def update_grad_norms(self):
@@ -293,7 +311,6 @@ class LoRAModule(torch.nn.Module):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
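
For reference, the approx_grad expression above is the first-order change of the merged delta ΔW = scale · B A when A and B each move along their current gradients (my reading of the code, not a statement from the GGPO write-up):

s\,(B + \nabla_B)(A + \nabla_A) - s\,BA \approx s\,(B\,\nabla_A + \nabla_B\,A)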
@property
def device(self):
return next(self.parameters()).device
@@ -544,6 +561,9 @@ def create_network(
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
lora_dropout = kwargs.get("lora_dropout", None)
if lora_dropout is not None:
lora_dropout = float(lora_dropout)
# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
@@ -564,7 +584,6 @@ def create_network(
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -585,6 +604,7 @@ def create_network(
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
lora_dropout=lora_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
train_blocks=train_blocks,
@@ -696,6 +716,7 @@ class LoRANetwork(torch.nn.Module):
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
lora_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
@@ -722,6 +743,7 @@ class LoRANetwork(torch.nn.Module):
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.lora_dropout = lora_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
@@ -757,7 +779,6 @@ class LoRANetwork(torch.nn.Module):
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
@@ -876,6 +897,7 @@ class LoRANetwork(torch.nn.Module):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
lora_dropout=lora_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
@@ -895,7 +917,7 @@ class LoRANetwork(torch.nn.Module):
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
logger.info(f"create LoRA for Text Encoder {index + 1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
@@ -976,7 +998,6 @@ class LoRANetwork(torch.nn.Module):
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file