Do not convert to float

This commit is contained in:
rockerBOO
2025-04-15 04:41:59 -04:00
parent 4d005cdf3d
commit c2f75f43a4

View File

@@ -9,10 +9,10 @@ class AID(nn.Module):
def forward(self, x: Tensor):
    """Apply sign-dependent stochastic gating to ``x``.

    In training mode, entries with ``x >= 0`` are kept with probability
    ``self.p`` and entries with ``x < 0`` with probability ``1 - self.p``,
    using independent Bernoulli draws. In eval mode the expected scaling
    (``p`` for the non-negative part, ``1 - p`` for the negative part) is
    applied deterministically instead.

    Args:
        x: Input tensor of any shape.

    Returns:
        Tensor of the same shape as ``x``.
    """
    if self.training:
        # Boolean sign masks multiply float Bernoulli draws directly;
        # bool * float promotes to float, so no explicit .float() cast
        # is needed (and the duplicated .float() variants were dead stores).
        pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
        neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
        return x * (pos_mask + neg_mask)
    else:
        # Deterministic expectation of the training-time mask.
        pos_part = (x >= 0) * x * self.p
        neg_part = (x < 0) * x * (1 - self.p)
        return pos_part + neg_part