Files
Kohya-ss-sd-scripts/library/model_utils.py
2025-04-13 13:57:52 -04:00

27 lines
859 B
Python

from torch import nn
import torch.nn.functional as F
import torch
class AID(nn.Module):
    """Activation that drops positive and negative inputs with different rates.

    In training mode each positive element is kept with probability
    ``dropout_prob`` and each negative element is kept with probability
    ``1 - dropout_prob``; dropped elements become zero and kept elements are
    passed through unscaled.

    In evaluation mode the module is deterministic and returns
    ``p * relu(x) - (1 - p) * relu(-x)``, i.e. the expected value of the
    training-mode output (a leaky-ReLU-like function with slope ``p`` on the
    positive side and ``1 - p`` on the negative side).

    Args:
        dropout_prob: Probability ``p`` of keeping a positive element; must
            lie in ``[0, 1]``. Despite the name, this is the keep probability
            for positive values, not the probability of dropping them.

    Raises:
        ValueError: If ``dropout_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, dropout_prob: float = 0.9) -> None:
        super().__init__()
        # Fail at construction time instead of deep inside torch.bernoulli
        # during the first training step (or silently in eval mode).
        if dropout_prob < 0 or dropout_prob > 1:
            raise ValueError(
                f"dropout_prob has to be between 0 and 1, but got {dropout_prob}"
            )
        self.p = dropout_prob
        # nn.Module.__init__ already sets self.training = True, so the mode
        # flag is not reassigned here; use .train() / .eval() to change it.

    def extra_repr(self) -> str:
        """Return the configuration string shown in the module's repr."""
        return f"dropout_prob={self.p}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stochastic (training) or deterministic (eval) activation.

        Args:
            x: Input tensor.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if self.training:
            # Independent Bernoulli keep-masks for the positive and the
            # negative part of the input.
            pos_mask = torch.bernoulli(torch.full_like(x, self.p))
            neg_mask = torch.bernoulli(torch.full_like(x, 1 - self.p))
            # relu(x) isolates the positive part, relu(-x) the magnitude of
            # the negative part; the latter is negated to restore its sign.
            pos_part = F.relu(x) * pos_mask
            neg_part = F.relu(-x) * neg_mask * -1
            return pos_part + neg_part
        # Evaluation: replace each mask by its expectation.
        return self.p * F.relu(x) + (1 - self.p) * F.relu(-x) * -1