diff --git a/library/model_utils.py b/library/model_utils.py index 3494e6f4..87c2d04a 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -9,10 +9,10 @@ class AID(nn.Module): def forward(self, x: Tensor): if self.training: - pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p) - neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) + pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p) + neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) return x * (pos_mask + neg_mask) else: - pos_part = (x >= 0).float() * x * self.p - neg_part = (x < 0).float() * x * (1 - self.p) + pos_part = (x >= 0) * x * self.p + neg_part = (x < 0) * x * (1 - self.p) return pos_part + neg_part