Initial Commit

This commit is contained in:
Victor Mylle
2023-11-07 18:00:20 +00:00
commit 56c763a6f4
41 changed files with 358954 additions and 0 deletions

1
src/losses/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .pinball_loss import PinballLoss

View File

@@ -0,0 +1,33 @@
import torch
from torch import nn
class PinballLoss(nn.Module):
"""
Calculates the quantile loss function.
Attributes
----------
self.pred : torch.tensor
Predictions.
self.target : torch.tensor
Target to predict.
self.quantiles : torch.tensor
"""
def __init__(self, quantiles):
super(PinballLoss, self).__init__()
self.quantiles_tensor = quantiles
self.quantiles = quantiles.tolist()
def forward(self, pred, target):
"""
Computes the loss for the given prediction.
"""
error = target - pred
upper = self.quantiles_tensor * error
lower = (self.quantiles_tensor - 1) * error
losses = torch.max(lower, upper)
loss = torch.mean(torch.sum(losses, dim=1))
return loss