Rewrote the NRVDataset to be cleaner
This commit is contained in:
@@ -1,30 +1,28 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch
|
||||
from properscoring import crps_ensemble
|
||||
|
||||
|
||||
class CRPSLoss(nn.Module):
|
||||
def __init__(self, quantiles):
|
||||
def __init__(self):
|
||||
super(CRPSLoss, self).__init__()
|
||||
|
||||
if not torch.is_tensor(quantiles):
|
||||
quantiles = torch.tensor(quantiles, dtype=torch.float32)
|
||||
self.quantiles_tensor = quantiles
|
||||
|
||||
def forward(self, preds, target):
|
||||
# if tensor, to cpu
|
||||
if isinstance(preds, torch.Tensor):
|
||||
preds = preds.detach().cpu()
|
||||
|
||||
if isinstance(target, torch.Tensor):
|
||||
target = target.detach().cpu()
|
||||
|
||||
# target squeeze -1
|
||||
target = target.squeeze(-1)
|
||||
|
||||
# preds shape: [batch_size, num_quantiles]
|
||||
|
||||
# unsqueeze target
|
||||
# target = target.unsqueeze(-1)
|
||||
|
||||
mask = (preds > target).float()
|
||||
self.quantiles_tensor = self.quantiles_tensor.to(preds.device)
|
||||
test = self.quantiles_tensor - mask
|
||||
# square them
|
||||
test = test * test
|
||||
crps = torch.trapz(test, x=preds)
|
||||
scores = crps_ensemble(target, preds)
|
||||
|
||||
# mean over batch
|
||||
crps = torch.mean(crps)
|
||||
crps = scores.mean()
|
||||
|
||||
return crps
|
||||
|
||||
Reference in New Issue
Block a user