Lot of changes

This commit is contained in:
Victor Mylle
2023-11-23 08:34:47 +00:00
parent 166d3967e1
commit 5de3f64a1a
9 changed files with 761 additions and 196234 deletions

View File

@@ -1 +1,2 @@
from .pinball_loss import PinballLoss, NonAutoRegressivePinballLoss
from .pinball_loss import PinballLoss, NonAutoRegressivePinballLoss
from .crps_metric import CRPSLoss

29
src/losses/crps_metric.py Normal file
View File

@@ -0,0 +1,29 @@
import torch
from torch import nn
import torch
class CRPSLoss(nn.Module):
def __init__(self, quantiles):
super(CRPSLoss, self).__init__()
if not torch.is_tensor(quantiles):
quantiles = torch.tensor(quantiles, dtype=torch.float32)
self.quantiles_tensor = quantiles
def forward(self, preds, target):
# preds shape: [batch_size, num_quantiles]
# unsqueeze target
target = target.unsqueeze(-1)
mask = (preds > target).float()
test = self.quantiles_tensor - mask
# square them
test = test * test
crps = torch.trapz(test, x=preds)
# mean over batch
crps = torch.mean(crps)
return crps