Add AID activation interval-wise dropout

Author: rockerBOO
Date: 2025-04-13 13:57:52 -04:00
Parent: 5a18a03ffc
Commit: dfae3a486c
2 changed files with 40 additions and 0 deletions

library/model_utils.py (new file, +26)

@@ -0,0 +1,26 @@
from torch import nn
import torch.nn.functional as F
import torch


class AID(nn.Module):
    def __init__(self, dropout_prob=0.9):
        super(AID, self).__init__()
        self.p = dropout_prob
        self.training = True

    def forward(self, x):
        if self.training:
            # Generate independent Bernoulli masks: keep positive values
            # with probability p, negative values with probability 1 - p
            pos_mask = torch.bernoulli(torch.full_like(x, self.p))
            neg_mask = torch.bernoulli(torch.full_like(x, 1 - self.p))

            # Apply the masks to the positive and negative parts separately
            pos_part = F.relu(x) * pos_mask
            neg_part = F.relu(-x) * neg_mask * -1
            return pos_part + neg_part
        else:
            # During inference, use a modified leaky ReLU with coefficient p
            # (the expected value of the stochastic training path)
            return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1
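
For intuition, a minimal usage sketch (not part of the commit): in training mode AID keeps positive activations with probability p and negative activations with probability 1 - p; in eval mode it applies a deterministic leaky-ReLU-like map whose output equals the expectation of the stochastic training path.

# Minimal sketch (not part of the commit) showing AID's two modes.
import torch
from library.model_utils import AID

aid = AID(dropout_prob=0.9)
x = torch.randn(4, 8)

aid.train()
y_stochastic = aid(x)     # positives survive w.p. 0.9, negatives w.p. 0.1

aid.eval()
y_deterministic = aid(x)  # 0.9 * relu(x) - 0.1 * relu(-x), no randomness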

LoRA network module (second changed file, +14; file name not shown in this view)

@@ -17,6 +17,7 @@ import numpy as np
import torch
from torch import Tensor
import re
from library.model_utils import AID
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -45,6 +46,7 @@ class LoRAModule(torch.nn.Module):
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
        aid_dropout=None,
        split_dims: Optional[List[int]] = None,
        ggpo_beta: Optional[float] = None,
        ggpo_sigma: Optional[float] = None,
@@ -107,6 +109,8 @@ class LoRAModule(torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.aid_dropout = aid_dropout  # stored: forward() checks this attribute
        self.aid = AID(dropout_prob=aid_dropout) if aid_dropout is not None else None  # AID activation
        self.ggpo_sigma = ggpo_sigma
        self.ggpo_beta = ggpo_beta
@@ -155,6 +159,9 @@ class LoRAModule(torch.nn.Module):
        lx = self.lora_up(lx)

        if self.aid_dropout is not None and self.training:
            lx = self.aid(lx)

        # LoRA Gradient-Guided Perturbation Optimization
        if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
            with torch.no_grad():
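
Condensed, the LoRAModule forward path after this change looks roughly like the following. This is a hedged sketch: org_forward, self.multiplier and self.scale are assumed from the usual sd-scripts LoRAModule layout and are not part of this diff.

# Hypothetical condensed view of LoRAModule.forward after this change.
def forward(self, x):
    lx = self.lora_down(x)
    lx = self.lora_up(lx)
    if self.aid_dropout is not None and self.training:
        lx = self.aid(lx)  # interval-wise dropout on the LoRA delta only
    return self.org_forward(x) + lx * self.multiplier * self.scale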
@@ -544,6 +551,9 @@ def create_network(
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    aid_dropout = kwargs.get("aid_dropout", None)
    if aid_dropout is not None:
        aid_dropout = float(aid_dropout)

    # single or double blocks
    train_blocks = kwargs.get("train_blocks", None)  # None (default), "all" (same as None), "single", "double"
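
Since network_args values arrive from the command line as strings, the explicit float() cast above is required. A hedged sketch of the plumbing (the --network_args key=value convention is the usual sd-scripts one; the exact trainer invocation is omitted):

# Hedged sketch: a flag like --network_args "aid_dropout=0.9" reaches
# create_network as a string-valued kwarg, hence the float() cast.
kwargs = {"aid_dropout": "0.9"}          # as parsed from the CLI
aid_dropout = kwargs.get("aid_dropout", None)
if aid_dropout is not None:
    aid_dropout = float(aid_dropout)     # "0.9" -> 0.9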
@@ -585,6 +595,7 @@ def create_network(
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        aid_dropout=aid_dropout,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
        train_blocks=train_blocks,
@@ -696,6 +707,7 @@ class LoRANetwork(torch.nn.Module):
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        aid_dropout: Optional[float] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
@@ -722,6 +734,7 @@ class LoRANetwork(torch.nn.Module):
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.aid_dropout = aid_dropout
        self.train_blocks = train_blocks if train_blocks is not None else "all"
        self.split_qkv = split_qkv
        self.train_t5xxl = train_t5xxl
@@ -876,6 +889,7 @@ class LoRANetwork(torch.nn.Module):
            dropout=dropout,
            rank_dropout=rank_dropout,
            module_dropout=module_dropout,
            aid_dropout=aid_dropout,
            split_dims=split_dims,
            ggpo_beta=ggpo_beta,
            ggpo_sigma=ggpo_sigma,
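
As a sanity check on the design (a sketch, not part of the commit): averaging many training-mode passes should converge to the eval-mode output, since the eval path is exactly the expectation of the masked training path.

# Sanity-check sketch: eval output equals the mean of training outputs.
import torch
from library.model_utils import AID

torch.manual_seed(0)
aid = AID(dropout_prob=0.9)
x = torch.randn(2, 3)

aid.train()
mean_train = torch.stack([aid(x) for _ in range(20000)]).mean(dim=0)

aid.eval()
print(torch.allclose(mean_train, aid(x), atol=1e-2))  # expected: True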