mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add LoRA dropout
Drop out rows and columns of the LoRA down/up weight matrices, following "LoRA Dropout as a Sparsity Regularizer for Overfitting Control"
This commit is contained in:
23
library/network_utils.py
Normal file
23
library/network_utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
|
||||
def lora_dropout_down(down: Tensor, x: Tensor, dropout_prob: float = 0.5):
    """A = A · diag(mA), mA ∼ Bern(1 − p)  (LoRA Dropout, sparsity regularizer).

    Randomly zeroes *columns* of the down-projection weight (i.e. input
    features) before projecting ``x`` through it.

    Args:
        down: down-projection weight of shape (rank, in_features).
        x: activations whose last dimension is in_features.
        dropout_prob: probability p of dropping each column.

    Returns:
        ``x`` projected through the column-masked down weight,
        shape (..., rank), same dtype as the inputs.
    """
    # Sample the mask in float32 (bernoulli is always supported there), then
    # cast to the weight's dtype so fp16/bf16 weights are not silently upcast —
    # an upcast would make the matmul below fail (or change the output dtype)
    # under mixed-precision training.
    mask = torch.bernoulli(
        torch.full((down.shape[1],), 1.0 - dropout_prob, device=down.device)
    ).to(down.dtype)

    # Apply input dimension mask (columns of down-projection)
    lx = x @ (down * mask.view(1, -1)).t()
    return lx
|
||||
|
||||
def lora_dropout_up(up: Tensor, x: Tensor, dropout_prob: float = 0.5):
    """B = Bᵀ · diag(mB)ᵀ, mB ∼ Bern(1 − p)  (LoRA Dropout, sparsity regularizer).

    Randomly zeroes *rows* of the up-projection weight (i.e. output
    features) before projecting ``x`` through it.

    Args:
        up: up-projection weight of shape (out_features, rank).
        x: activations whose last dimension is rank.
        dropout_prob: probability p of dropping each row.

    Returns:
        ``x`` projected through the row-masked up weight,
        shape (..., out_features), same dtype as the inputs.
    """
    # Sample the mask in float32 (bernoulli is always supported there), then
    # cast to the weight's dtype so fp16/bf16 weights are not silently upcast —
    # an upcast would make the matmul below fail (or change the output dtype)
    # under mixed-precision training.
    mask = torch.bernoulli(
        torch.full((up.shape[0],), 1.0 - dropout_prob, device=up.device)
    ).to(up.dtype)

    # Apply output dimension mask (rows of up-projection)
    lx = x @ (up * mask.view(-1, 1)).t()
    return lx
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
from library.network_utils import lora_dropout_down, lora_dropout_up
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -45,6 +45,7 @@ class LoRAModule(torch.nn.Module):
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
lora_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
@@ -106,6 +107,7 @@ class LoRAModule(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.lora_dropout = lora_dropout
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
@@ -132,7 +134,11 @@ class LoRAModule(torch.nn.Module):
|
||||
return org_forwarded
|
||||
|
||||
if self.split_dims is None:
|
||||
lx = self.lora_down(x)
|
||||
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
|
||||
if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
|
||||
lx = lora_dropout_down(self.lora_down.weight, x, dropout_prob=self.lora_dropout)
|
||||
else:
|
||||
lx = self.lora_down(x)
|
||||
|
||||
# normal dropout
|
||||
if self.dropout is not None and self.training:
|
||||
@@ -153,14 +159,26 @@ class LoRAModule(torch.nn.Module):
|
||||
else:
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
# LoRA Dropout as a Sparsity Regularizer for Overfitting Control
|
||||
if self.lora_dropout is not None and self.training and self.lora_dropout > 0:
|
||||
lx = lora_dropout_up(self.lora_up.weight, lx, dropout_prob=self.lora_dropout)
|
||||
else:
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
if (
|
||||
self.training
|
||||
and self.ggpo_sigma is not None
|
||||
and self.ggpo_beta is not None
|
||||
and self.combined_weight_norms is not None
|
||||
and self.grad_norms is not None
|
||||
):
|
||||
with torch.no_grad():
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
|
||||
self.ggpo_beta * (self.grad_norms**2)
|
||||
)
|
||||
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(perturbation_scale_factor)
|
||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
||||
@@ -197,24 +215,24 @@ class LoRAModule(torch.nn.Module):
|
||||
# Choose a reasonable sample size
|
||||
n_rows = org_module_weight.shape[0]
|
||||
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
|
||||
|
||||
|
||||
# Sample random indices across all rows
|
||||
indices = torch.randperm(n_rows)[:sample_size]
|
||||
|
||||
|
||||
# Convert to a supported data type first, then index
|
||||
# Use float32 for indexing operations
|
||||
weights_float32 = org_module_weight.to(dtype=torch.float32)
|
||||
sampled_weights = weights_float32[indices].to(device=self.device)
|
||||
|
||||
|
||||
# Calculate sampled norms
|
||||
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
|
||||
|
||||
|
||||
# Store the mean norm as our estimate
|
||||
self.org_weight_norm_estimate = sampled_norms.mean()
|
||||
|
||||
|
||||
# Optional: store standard deviation for confidence intervals
|
||||
self.org_weight_norm_std = sampled_norms.std()
|
||||
|
||||
|
||||
# Free memory
|
||||
del sampled_weights, weights_float32
|
||||
|
||||
@@ -223,37 +241,36 @@ class LoRAModule(torch.nn.Module):
|
||||
# Calculate the true norm (this will be slow but it's just for validation)
|
||||
true_norms = []
|
||||
chunk_size = 1024 # Process in chunks to avoid OOM
|
||||
|
||||
|
||||
for i in range(0, org_module_weight.shape[0], chunk_size):
|
||||
end_idx = min(i + chunk_size, org_module_weight.shape[0])
|
||||
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
|
||||
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
|
||||
true_norms.append(chunk_norms.cpu())
|
||||
del chunk
|
||||
|
||||
|
||||
true_norms = torch.cat(true_norms, dim=0)
|
||||
true_mean_norm = true_norms.mean().item()
|
||||
|
||||
|
||||
# Compare with our estimate
|
||||
estimated_norm = self.org_weight_norm_estimate.item()
|
||||
|
||||
|
||||
# Calculate error metrics
|
||||
absolute_error = abs(true_mean_norm - estimated_norm)
|
||||
relative_error = absolute_error / true_mean_norm * 100 # as percentage
|
||||
|
||||
|
||||
if verbose:
|
||||
logger.info(f"True mean norm: {true_mean_norm:.6f}")
|
||||
logger.info(f"Estimated norm: {estimated_norm:.6f}")
|
||||
logger.info(f"Absolute error: {absolute_error:.6f}")
|
||||
logger.info(f"Relative error: {relative_error:.2f}%")
|
||||
|
||||
return {
|
||||
'true_mean_norm': true_mean_norm,
|
||||
'estimated_norm': estimated_norm,
|
||||
'absolute_error': absolute_error,
|
||||
'relative_error': relative_error
|
||||
}
|
||||
|
||||
return {
|
||||
"true_mean_norm": true_mean_norm,
|
||||
"estimated_norm": estimated_norm,
|
||||
"absolute_error": absolute_error,
|
||||
"relative_error": relative_error,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def update_norms(self):
|
||||
@@ -261,7 +278,7 @@ class LoRAModule(torch.nn.Module):
|
||||
if self.ggpo_beta is None or self.ggpo_sigma is None:
|
||||
return
|
||||
|
||||
# only update norms when we are training
|
||||
# only update norms when we are training
|
||||
if self.training is False:
|
||||
return
|
||||
|
||||
@@ -269,8 +286,9 @@ class LoRAModule(torch.nn.Module):
|
||||
module_weights.mul(self.scale)
|
||||
|
||||
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
|
||||
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
|
||||
torch.sum(module_weights**2, dim=1, keepdim=True))
|
||||
self.combined_weight_norms = torch.sqrt(
|
||||
(self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def update_grad_norms(self):
|
||||
@@ -293,7 +311,6 @@ class LoRAModule(torch.nn.Module):
|
||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
@@ -544,6 +561,9 @@ def create_network(
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
lora_dropout = kwargs.get("lora_dropout", None)
|
||||
if lora_dropout is not None:
|
||||
lora_dropout = float(lora_dropout)
|
||||
|
||||
# single or double blocks
|
||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||
@@ -564,7 +584,6 @@ def create_network(
|
||||
if ggpo_sigma is not None:
|
||||
ggpo_sigma = float(ggpo_sigma)
|
||||
|
||||
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
@@ -585,6 +604,7 @@ def create_network(
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
lora_dropout=lora_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
@@ -696,6 +716,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
lora_dropout: Optional[float] = None,
|
||||
conv_lora_dim: Optional[int] = None,
|
||||
conv_alpha: Optional[float] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
@@ -722,6 +743,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.lora_dropout = lora_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
self.split_qkv = split_qkv
|
||||
self.train_t5xxl = train_t5xxl
|
||||
@@ -757,7 +779,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
if self.train_blocks is not None:
|
||||
logger.info(f"train {self.train_blocks} blocks only")
|
||||
|
||||
|
||||
if train_t5xxl:
|
||||
logger.info(f"train T5XXL as well")
|
||||
|
||||
@@ -876,6 +897,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
lora_dropout=lora_dropout,
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
@@ -895,7 +917,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
|
||||
break
|
||||
|
||||
logger.info(f"create LoRA for Text Encoder {index+1}:")
|
||||
logger.info(f"create LoRA for Text Encoder {index + 1}:")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
|
||||
@@ -976,7 +998,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
Reference in New Issue
Block a user