Added new training scripts
This commit is contained in:
@@ -18,6 +18,7 @@ class CRPSLoss(nn.Module):
|
||||
# target = target.unsqueeze(-1)
|
||||
|
||||
mask = (preds > target).float()
|
||||
self.quantiles_tensor = self.quantiles_tensor.to(preds.device)
|
||||
test = self.quantiles_tensor - mask
|
||||
# square them
|
||||
test = test * test
|
||||
|
||||
@@ -9,8 +9,9 @@ class PinballLoss(nn.Module):
|
||||
|
||||
def forward(self, pred, target):
|
||||
error = target - pred
|
||||
upper = self.quantiles_tensor * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
quantiles = self.quantiles_tensor.to(error.device)
|
||||
upper = quantiles * error
|
||||
lower = (quantiles - 1) * error
|
||||
losses = torch.max(lower, upper)
|
||||
loss = torch.mean(torch.mean(losses, dim=0))
|
||||
return loss
|
||||
@@ -26,8 +27,10 @@ class NonAutoRegressivePinballLoss(nn.Module):
|
||||
pred = pred.reshape(-1, 96, len(self.quantiles_tensor))
|
||||
target_expanded = target.unsqueeze(2)
|
||||
error = target_expanded - pred
|
||||
upper = self.quantiles_tensor * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
quantiles = self.quantiles_tensor.to(error.device)
|
||||
|
||||
upper = quantiles * error
|
||||
lower = (quantiles - 1) * error
|
||||
losses = torch.max(lower, upper)
|
||||
loss = torch.mean(losses)
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user