Files
Kohya-ss-sd-scripts/library/network_utils.py
rockerBOO 4230f882d1 Add LoRA dropout
Dropout weights of rows and columns of LoRA down/up
LoRA Dropout as a Sparsity Regularizer for Overfitting Control
2025-04-13 02:26:06 -04:00

24 lines
796 B
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from torch import Tensor
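# LoRA Dropout (from "LoRA Dropout as a Sparsity Regularizer for Overfitting Control"):
# randomly zero columns of the LoRA down-projection and rows of the up-projection.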
def lora_dropout_down(down: Tensor, x: Tensor, dropout_prob=0.5):
    """Apply the LoRA down-projection with random input-column dropout.

    Implements ``A' = A · diag(m_A)`` with ``m_A ~ Bernoulli(1 - p)``: each
    column of ``down`` (indexed by ``down.shape[1]``) is independently zeroed
    with probability ``dropout_prob``, then ``x @ A'.T`` is returned. The
    surviving columns are NOT rescaled by ``1 / (1 - p)``.

    Args:
        down: 2-D down-projection weight whose columns are masked. Presumably
            shaped ``(rank, in_features)`` -- TODO confirm against caller.
        x: Input whose last dimension equals ``down.shape[1]``; it should share
            ``down``'s dtype for the matmul to succeed.
        dropout_prob: Probability of dropping each column, expected in [0, 1].

    Returns:
        ``x @ (down * mask).T`` for a freshly sampled 0/1 column mask.
    """
    # Sample keep/drop decisions with probability (1 - p) of keeping.
    keep_prob = torch.ones(down.shape[1], device=down.device) * (1 - dropout_prob)
    # The sampled mask is float32 (default dtype); cast it to the weight's dtype
    # so fp16/bf16 weights are not promoted to float32, which would make the
    # matmul with a low-precision ``x`` fail. 0/1 values are exact in any dtype.
    mask = torch.bernoulli(keep_prob).to(down.dtype)
    # Zero whole columns (input dimensions) of the down-projection.
    return x @ (down * mask.view(1, -1)).t()
def lora_dropout_up(up: Tensor, x: Tensor, dropout_prob=0.5):
    """Apply the LoRA up-projection with random output-row dropout.

    Implements ``B' = diag(m_B) · B`` with ``m_B ~ Bernoulli(1 - p)``: each
    row of ``up`` (indexed by ``up.shape[0]``) is independently zeroed with
    probability ``dropout_prob``, then ``x @ B'.T`` is returned. The
    surviving rows are NOT rescaled by ``1 / (1 - p)``.

    Args:
        up: 2-D up-projection weight whose rows are masked. Presumably shaped
            ``(out_features, rank)`` -- TODO confirm against caller.
        x: Input whose last dimension equals ``up.shape[1]``; it should share
            ``up``'s dtype for the matmul to succeed.
        dropout_prob: Probability of dropping each row, expected in [0, 1].

    Returns:
        ``x @ (up * mask).T`` for a freshly sampled 0/1 row mask.
    """
    # Sample keep/drop decisions with probability (1 - p) of keeping.
    keep_prob = torch.ones(up.shape[0], device=up.device) * (1 - dropout_prob)
    # The sampled mask is float32 (default dtype); cast it to the weight's dtype
    # so fp16/bf16 weights are not promoted to float32, which would make the
    # matmul with a low-precision ``x`` fail. 0/1 values are exact in any dtype.
    mask = torch.bernoulli(keep_prob).to(up.dtype)
    # Zero whole rows (output dimensions) of the up-projection.
    return x @ (up * mask.view(-1, 1)).t()