mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Do not convert to float
This commit is contained in:
@@ -9,10 +9,10 @@ class AID(nn.Module):
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
if self.training:
|
||||
pos_mask = (x >= 0).float() * torch.bernoulli(torch.ones_like(x) * self.p)
|
||||
neg_mask = (x < 0).float() * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
|
||||
pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
|
||||
neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
|
||||
return x * (pos_mask + neg_mask)
|
||||
else:
|
||||
pos_part = (x >= 0).float() * x * self.p
|
||||
neg_part = (x < 0).float() * x * (1 - self.p)
|
||||
pos_part = (x >= 0) * x * self.p
|
||||
neg_part = (x < 0) * x * (1 - self.p)
|
||||
return pos_part + neg_part
|
||||
|
||||
Reference in New Issue
Block a user