# Mirror of https://github.com/kohya-ss/sd-scripts.git
# (synced 2026-04-08 14:34:23 +00:00; original file: 19 lines, 560 B, Python)
import torch
from torch import nn, Tensor


class AID(nn.Module):
    """Sign-dependent dropout.

    In training mode each element is zeroed at random with a keep probability
    that depends on its sign: non-negative elements are kept with probability
    ``p`` and negative elements with probability ``1 - p``. Kept elements are
    passed through unscaled (no inverted-dropout rescaling).

    In eval mode the layer is deterministic and returns the expected value of
    the training-mode output: non-negative elements are multiplied by ``p``
    and negative elements by ``1 - p``.

    Args:
        p: Keep probability for non-negative elements; negative elements use
            ``1 - p``. Must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p=0.9):
        super().__init__()
        # torch.bernoulli would reject such a value only at the first training
        # step; in eval mode it would silently yield meaningless scaling.
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.p = p

    def extra_repr(self) -> str:
        """Show ``p`` in the module's repr, e.g. ``AID(p=0.9)``."""
        return f"p={self.p}"

    def _keep_prob(self, x: Tensor) -> Tensor:
        """Return a tensor shaped like ``x``: ``p`` where ``x >= 0``, else ``1 - p``."""
        # Integer inputs need a floating-point probability tensor (the
        # original code promoted to the default float dtype in that case);
        # floating inputs keep their own dtype so fp16/bf16 stay in place.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        return torch.full_like(x, 1.0 - self.p, dtype=dtype).masked_fill_(x >= 0, self.p)

    def forward(self, x: Tensor) -> Tensor:
        """Apply sign-dependent dropout (training) or its expectation (eval)."""
        keep_prob = self._keep_prob(x)
        if self.training:
            # A single Bernoulli draw with a per-element keep probability has
            # the same distribution as separate positive/negative draws, at
            # half the RNG and temporary-memory cost.
            return x * torch.bernoulli(keep_prob)
        # Multiply by a finite per-element scale instead of summing two
        # sign-masked products: 0 * inf would turn +/-inf into NaN.
        return x * keep_prob