From c2f75f43a41f9d1abefe90e5f458d262be37ef53 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 15 Apr 2025 04:41:59 -0400 Subject: [PATCH] Do not convert to float --- library/model_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/model_utils.py b/library/model_utils.py index 3494e6f4..87c2d04a 100644 --- a/library/model_utils.py +++ b/library/model_utils.py @@ -9,10 +9,10 @@ class AID(nn.Module): def forward(self, x: Tensor): if self.training: - pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p) - neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) + pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p) + neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p)) return x * (pos_mask + neg_mask) else: - pos_part = (x >= 0).float() * x * self.p - neg_part = (x < 0).float() * x * (1 - self.p) + pos_part = (x >= 0) * x * self.p + neg_part = (x < 0) * x * (1 - self.p) return pos_part + neg_part