mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Add AID activation interval-wise dropout
This commit is contained in:
26
library/model_utils.py
Normal file
26
library/model_utils.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
|
||||
class AID(nn.Module):
    """Activation Interval-wise Dropout (AID).

    During training each element's positive part is kept with probability
    ``p`` and its negative part is kept with probability ``1 - p``, using
    independent Bernoulli masks. During evaluation the deterministic
    expectation of that stochastic output is returned, which behaves like a
    leaky ReLU: ``p * relu(x) - (1 - p) * relu(-x)``.

    Args:
        dropout_prob: keep probability ``p`` for positive activations (and
            drop probability for negative ones). Expected to lie in [0, 1].
    """

    def __init__(self, dropout_prob=0.9):
        # NOTE: nn.Module.__init__ already sets self.training = True and the
        # flag is toggled via .train()/.eval(); assigning it by hand here (as
        # the original did) was redundant and misleading, so it is removed.
        super().__init__()
        self.p = dropout_prob  # keep-probability for the positive part

    def forward(self, x):
        if self.training:
            # Independent element-wise Bernoulli masks:
            # keep positives with prob p, keep negatives with prob 1 - p.
            pos_mask = torch.bernoulli(torch.full_like(x, self.p))
            neg_mask = torch.bernoulli(torch.full_like(x, 1 - self.p))

            # relu(x) is the positive part; -relu(-x) reconstructs the
            # negative part, each gated by its own mask.
            pos_part = F.relu(x) * pos_mask
            neg_part = F.relu(-x) * neg_mask * -1

            return pos_part + neg_part
        else:
            # Deterministic expectation of the training-time output
            # (a leaky-ReLU with negative-side slope 1 - p).
            return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1
|
||||
|
||||
@@ -17,6 +17,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.model_utils import AID
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
@@ -45,6 +46,7 @@ class LoRAModule(torch.nn.Module):
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
aid_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
@@ -107,6 +109,8 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.aid = AID(dropout_prob=aid_dropout) # AID activation
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
@@ -155,6 +159,9 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
if self.aid_dropout is not None and self.training:
|
||||
lx = self.aid(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
with torch.no_grad():
|
||||
@@ -544,6 +551,9 @@ def create_network(
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
aid_dropout = kwargs.get("aid_dropout", None)
|
||||
if aid_dropout is not None:
|
||||
aid_dropout = float(aid_dropout)
|
||||
|
||||
# single or double blocks
|
||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||
@@ -585,6 +595,7 @@ def create_network(
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
aid_dropout=aid_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
@@ -696,6 +707,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
aid_dropout: Optional[float] = None,
|
||||
conv_lora_dim: Optional[int] = None,
|
||||
conv_alpha: Optional[float] = None,
|
||||
module_class: Type[object] = LoRAModule,
|
||||
@@ -722,6 +734,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.aid_dropout = aid_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
self.split_qkv = split_qkv
|
||||
self.train_t5xxl = train_t5xxl
|
||||
@@ -876,6 +889,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
aid_dropout=aid_dropout,
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
|
||||
Reference in New Issue
Block a user