# Mirror of https://github.com/kohya-ss/sd-scripts.git
# (synced 2026-04-10 15:00:23 +00:00; original file: 27 lines, 859 B, Python)
from torch import nn
import torch.nn.functional as F
import torch


class AID(nn.Module):
    """AID activation dropout with a sign-dependent keep probability.

    Training mode: each element is independently kept or zeroed.
    Positive activations are kept with probability ``p``; negative
    activations are kept with probability ``1 - p``. Kept values are
    not rescaled.

    Eval mode: the layer deterministically returns the expectation of
    the training-mode output, ``p * relu(x) - (1 - p) * relu(-x)``. This
    is a modified leaky ReLU with slope ``p`` for positive inputs and
    ``1 - p`` for negative inputs.

    Args:
        dropout_prob: Keep probability ``p`` for positive activations;
            negative activations use ``1 - p``. Despite the name this is
            a *keep* probability, not a drop probability. Must lie in
            ``[0, 1]``.

    Raises:
        ValueError: If ``dropout_prob`` is outside ``[0, 1]`` (or NaN).
    """

    def __init__(self, dropout_prob=0.9):
        super().__init__()  # nn.Module.__init__ already sets self.training = True
        # Fail fast here instead of deep inside forward(); also rejects NaN.
        if not 0.0 <= dropout_prob <= 1.0:
            raise ValueError(f"dropout_prob must be in [0, 1], got {dropout_prob}")
        self.p = dropout_prob

    def extra_repr(self):
        return f"p={self.p}"

    def forward(self, x):
        """Apply AID to ``x``.

        The output has the same shape and dtype as ``x``.
        """
        if self.training:
            # Draw independent Bernoulli keep-masks for the two signs.
            # Passing p to bernoulli_ as a Python float avoids building a
            # probability tensor in x.dtype. torch.full_like(x, p) would
            # round p there (e.g. 0.9 -> 0.8984375 in bfloat16), biasing
            # training away from the eval expectation below. It would also
            # truncate p to 0 for integer x.
            pos_mask = torch.empty_like(x).bernoulli_(self.p)
            neg_mask = torch.empty_like(x).bernoulli_(1 - self.p)

            # relu(x) and relu(-x) are disjoint, so each element is gated
            # by exactly one of the two masks.
            pos_part = F.relu(x) * pos_mask
            neg_part = F.relu(-x) * neg_mask * -1
            return pos_part + neg_part

        # Eval: deterministic expectation of the training-mode output.
        return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1