Fixed crps + more inputs
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch
|
||||
from properscoring import crps_ensemble
|
||||
import numpy as np
|
||||
from scipy.interpolate import CubicSpline
|
||||
|
||||
|
||||
class CRPSLoss(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, quantiles):
|
||||
super(CRPSLoss, self).__init__()
|
||||
self.quantiles = quantiles
|
||||
|
||||
def forward(self, preds, target):
|
||||
# if tensor, to cpu
|
||||
@@ -16,13 +17,30 @@ class CRPSLoss(nn.Module):
|
||||
if isinstance(target, torch.Tensor):
|
||||
target = target.detach().cpu()
|
||||
|
||||
# target squeeze -1
|
||||
target = target.squeeze(-1)
|
||||
# if preds more than 2 dimensions, flatten to 2
|
||||
if len(preds.shape) > 2:
|
||||
preds = preds.reshape(-1, preds.shape[-1])
|
||||
# target will be reshaped from (1024, 96, 15) to (1024*96, 15)
|
||||
# our target (1024, 96) also needs to be reshaped to (1024*96, 1)
|
||||
target = target.reshape(-1, 1)
|
||||
|
||||
# preds shape: [batch_size, num_quantiles]
|
||||
scores = crps_ensemble(target, preds)
|
||||
# preds and target as numpy
|
||||
preds = preds.numpy()
|
||||
target = target.numpy()
|
||||
|
||||
# mean over batch
|
||||
crps = scores.mean()
|
||||
|
||||
return crps
|
||||
n_x = 101
|
||||
probs = np.linspace(0, 1, n_x)
|
||||
|
||||
spline = CubicSpline(self.quantiles, preds, axis=1)
|
||||
|
||||
imbalances = spline(probs)
|
||||
|
||||
larger_than_label = imbalances > target
|
||||
tiled_probs = np.tile(probs, (len(imbalances), 1))
|
||||
tiled_probs[larger_than_label] -= 1
|
||||
|
||||
crps_per_sample = np.trapz(tiled_probs ** 2, imbalances, axis=-1)
|
||||
crps = np.mean(crps_per_sample)
|
||||
|
||||
return crps
|
||||
Reference in New Issue
Block a user