76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
import torch
|
|
from torch import nn
|
|
import numpy as np
|
|
from scipy.interpolate import CubicSpline
|
|
|
|
|
|
class CRPSLoss(nn.Module):
    """Continuous Ranked Probability Score for quantile forecasts.

    Each prediction row holds forecast values at the probability levels
    ``quantiles``. A cubic spline through those points approximates the
    quantile function, which is evaluated on an evenly spaced probability
    grid in [0, 1]. The CRPS integral of ``(F(x) - 1{x > y})**2 dx`` is then
    computed with the trapezoidal rule along the interpolated values.

    Inputs are detached and converted to NumPy, and the result is a NumPy
    scalar. This module is therefore an evaluation metric, not a
    differentiable training loss.
    """

    def __init__(self, quantiles, n_points=101):
        """
        :param quantiles: increasing probability levels at which ``preds`` are
            given (used as the ``x`` knots of ``scipy.interpolate.CubicSpline``).
        :param n_points: number of evenly spaced probability levels in [0, 1]
            at which the spline is evaluated for the integral (default 101).
        """
        super().__init__()
        self.quantiles = quantiles
        self.n_points = n_points

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, detaching torch tensors and moving them to CPU first."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def forward(self, preds, target):
        """Return the CRPS averaged over all forecast rows.

        :param preds: tensor or array-like whose last axis indexes the quantile
            levels; all leading axes are flattened into rows, e.g.
            ``(batch, horizon, n_quantiles) -> (batch * horizon, n_quantiles)``.
        :param target: observations with one value per row of ``preds`` (any
            shape with that many elements, e.g. ``(batch, horizon)``), or a
            single value shared by all rows.
        :return: NumPy float scalar, the mean CRPS over rows.
        """
        preds = self._to_numpy(preds)
        target = self._to_numpy(target)
        quantiles = self._to_numpy(self.quantiles)

        # Always flatten, not only for >2-D input: previously a 2-D ``preds``
        # with a 1-D ``target`` broadcast the target along the probability grid
        # instead of along the rows.
        preds = preds.reshape(-1, preds.shape[-1])
        target = target.reshape(-1, 1)

        # getattr keeps instances pickled before ``n_points`` existed working.
        n_points = getattr(self, "n_points", 101)
        probs = np.linspace(0.0, 1.0, n_points)

        # Interpolated quantile function per row; the spline extrapolates to
        # probability levels outside the range covered by ``quantiles``.
        spline = CubicSpline(quantiles, preds, axis=1)
        values = spline(probs)

        # Integrand (F(x) - 1{x > y})**2, where F(x) is the probability level
        # at which the quantile function reaches x.
        above_target = values > target
        integrand = (probs - above_target.astype(probs.dtype)) ** 2

        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        crps_per_sample = trapezoid(integrand, values, axis=-1)
        return np.mean(crps_per_sample)


def crps_from_samples(samples, targets):
    """
    Compute the Continuous Ranked Probability Score (CRPS) from multi-day samples and targets
    using a vectorized approach with PyTorch tensors.

    Uses the ensemble form ``CRPS = E|X - y| - 0.5 * E|X - X'|``, with
    expectations taken over the empirical distribution of the samples
    (all ordered pairs, including i == j, in the second term). The
    ``E|X - X'|`` term is computed from the sorted samples. This needs
    O(n_samples log n_samples) time and O(n_samples) memory per
    (day, timestep), instead of materialising all n_samples**2 pairwise
    differences.

    :param samples: (day, n_samples, n_timesteps) tensor of forecasted samples
    :param targets: (day, n_timesteps) tensor of observed values
    :return: (day, n_timesteps) tensor of CRPS for each timestep for each day
    :raises ValueError: if ``samples`` is not 3-dimensional
    """
    # Shape unpacking doubles as a rank check (ValueError for non-3-D input).
    _, n_samples, _ = samples.shape

    # E|X - y|: broadcast targets along the samples dimension.
    term1 = torch.mean(torch.abs(samples - targets.unsqueeze(1)), dim=1)

    # 0.5 * E|X - X'|. For sorted x_(1) <= ... <= x_(n):
    #   sum_{i,j} |x_i - x_j| = 2 * sum_k (2k - n - 1) * x_(k),
    # so half of the mean pairwise difference is sum_k (2k - n - 1) * x_(k) / n**2.
    sorted_samples, _ = torch.sort(samples, dim=1)
    ranks = torch.arange(1, n_samples + 1, dtype=samples.dtype, device=samples.device)
    weights = (2 * ranks - n_samples - 1).view(1, -1, 1)
    term2 = torch.sum(weights * sorted_samples, dim=1) / (n_samples ** 2)

    # CRPS for each timestep for each day
    return term1 - term2