Fixed crps + more inputs

This commit is contained in:
Victor Mylle
2023-12-05 00:08:17 +00:00
parent 120b6aa5bd
commit d3bf04d68c
13 changed files with 128426 additions and 70 deletions

View File

@@ -1,12 +1,13 @@
import torch
from torch import nn
import torch
from properscoring import crps_ensemble
import numpy as np
from scipy.interpolate import CubicSpline
class CRPSLoss(nn.Module):
def __init__(self):
def __init__(self, quantiles):
super(CRPSLoss, self).__init__()
self.quantiles = quantiles
def forward(self, preds, target):
# if tensor, to cpu
@@ -16,13 +17,30 @@ class CRPSLoss(nn.Module):
if isinstance(target, torch.Tensor):
target = target.detach().cpu()
# target squeeze -1
target = target.squeeze(-1)
# if preds more than 2 dimensions, flatten to 2
if len(preds.shape) > 2:
preds = preds.reshape(-1, preds.shape[-1])
# target will be reshaped from (1024, 96, 15) to (1024*96, 15)
# our target (1024, 96) also needs to be reshaped to (1024*96, 1)
target = target.reshape(-1, 1)
# preds shape: [batch_size, num_quantiles]
scores = crps_ensemble(target, preds)
# preds and target as numpy
preds = preds.numpy()
target = target.numpy()
# mean over batch
crps = scores.mean()
return crps
n_x = 101
probs = np.linspace(0, 1, n_x)
spline = CubicSpline(self.quantiles, preds, axis=1)
imbalances = spline(probs)
larger_than_label = imbalances > target
tiled_probs = np.tile(probs, (len(imbalances), 1))
tiled_probs[larger_than_label] -= 1
crps_per_sample = np.trapz(tiled_probs ** 2, imbalances, axis=-1)
crps = np.mean(crps_per_sample)
return crps