Rewrote the NRVDataset to be cleaner

This commit is contained in:
Victor Mylle
2023-11-28 15:35:35 +00:00
parent f9e8f9e69f
commit ffa19592f9
3 changed files with 83 additions and 181 deletions

View File

@@ -1,30 +1,28 @@
import torch
from torch import nn
import torch
from properscoring import crps_ensemble
class CRPSLoss(nn.Module):
def __init__(self, quantiles):
def __init__(self):
super(CRPSLoss, self).__init__()
if not torch.is_tensor(quantiles):
quantiles = torch.tensor(quantiles, dtype=torch.float32)
self.quantiles_tensor = quantiles
def forward(self, preds, target):
# if tensor, to cpu
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu()
if isinstance(target, torch.Tensor):
target = target.detach().cpu()
# target squeeze -1
target = target.squeeze(-1)
# preds shape: [batch_size, num_quantiles]
# unsqueeze target
# target = target.unsqueeze(-1)
mask = (preds > target).float()
self.quantiles_tensor = self.quantiles_tensor.to(preds.device)
test = self.quantiles_tensor - mask
# square them
test = test * test
crps = torch.trapz(test, x=preds)
scores = crps_ensemble(target, preds)
# mean over batch
crps = torch.mean(crps)
crps = scores.mean()
return crps