Initial Commit
This commit is contained in:
1
src/losses/__init__.py
Normal file
1
src/losses/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .pinball_loss import PinballLoss
|
||||
33
src/losses/pinball_loss.py
Normal file
33
src/losses/pinball_loss.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class PinballLoss(nn.Module):
|
||||
"""
|
||||
Calculates the quantile loss function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
self.pred : torch.tensor
|
||||
Predictions.
|
||||
self.target : torch.tensor
|
||||
Target to predict.
|
||||
self.quantiles : torch.tensor
|
||||
"""
|
||||
def __init__(self, quantiles):
|
||||
super(PinballLoss, self).__init__()
|
||||
self.quantiles_tensor = quantiles
|
||||
self.quantiles = quantiles.tolist()
|
||||
|
||||
def forward(self, pred, target):
|
||||
"""
|
||||
Computes the loss for the given prediction.
|
||||
"""
|
||||
|
||||
error = target - pred
|
||||
upper = self.quantiles_tensor * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
|
||||
losses = torch.max(lower, upper)
|
||||
loss = torch.mean(torch.sum(losses, dim=1))
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user