Sped up sampling 20x

This commit is contained in:
Victor Mylle
2023-11-25 18:09:42 +00:00
parent 5de3f64a1a
commit 300f268286
10 changed files with 498 additions and 238 deletions

View File

@@ -15,7 +15,7 @@ class CRPSLoss(nn.Module):
# preds shape: [batch_size, num_quantiles]
# unsqueeze target
target = target.unsqueeze(-1)
# target = target.unsqueeze(-1)
mask = (preds > target).float()
test = self.quantiles_tensor - mask

View File

@@ -1,24 +1,27 @@
import torch
from torch import nn
class PinballLoss(nn.Module):
def __init__(self, quantiles):
super(PinballLoss, self).__init__()
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
self.quantiles = self.quantiles_tensor.tolist()
def forward(self, pred, target):
error = target - pred
upper = self.quantiles_tensor * error
lower = (self.quantiles_tensor - 1) * error
lower = (self.quantiles_tensor - 1) * error
losses = torch.max(lower, upper)
loss = torch.mean(torch.mean(losses, dim=0))
return loss
class NonAutoRegressivePinballLoss(nn.Module):
def __init__(self, quantiles):
super(NonAutoRegressivePinballLoss, self).__init__()
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
self.quantiles = self.quantiles_tensor.tolist()
def forward(self, pred, target):
pred = pred.reshape(-1, 96, len(self.quantiles_tensor))