Lot of changes
This commit is contained in:
@@ -1 +1,2 @@
|
||||
from .pinball_loss import PinballLoss, NonAutoRegressivePinballLoss
|
||||
from .pinball_loss import PinballLoss, NonAutoRegressivePinballLoss
|
||||
from .crps_metric import CRPSLoss
|
||||
|
||||
29
src/losses/crps_metric.py
Normal file
29
src/losses/crps_metric.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
|
||||
class CRPSLoss(nn.Module):
|
||||
def __init__(self, quantiles):
|
||||
super(CRPSLoss, self).__init__()
|
||||
|
||||
if not torch.is_tensor(quantiles):
|
||||
quantiles = torch.tensor(quantiles, dtype=torch.float32)
|
||||
self.quantiles_tensor = quantiles
|
||||
|
||||
def forward(self, preds, target):
|
||||
# preds shape: [batch_size, num_quantiles]
|
||||
|
||||
# unsqueeze target
|
||||
target = target.unsqueeze(-1)
|
||||
|
||||
mask = (preds > target).float()
|
||||
test = self.quantiles_tensor - mask
|
||||
# square them
|
||||
test = test * test
|
||||
crps = torch.trapz(test, x=preds)
|
||||
|
||||
# mean over batch
|
||||
crps = torch.mean(crps)
|
||||
|
||||
return crps
|
||||
Reference in New Issue
Block a user