Added new training scripts

This commit is contained in:
Victor Mylle
2023-11-27 14:55:22 +00:00
parent 5e87165dbb
commit c1152ff96c
7 changed files with 37 additions and 36 deletions

View File

@@ -18,6 +18,7 @@ class CRPSLoss(nn.Module):
# target = target.unsqueeze(-1)
mask = (preds > target).float()
self.quantiles_tensor = self.quantiles_tensor.to(preds.device)
test = self.quantiles_tensor - mask
# square them
test = test * test

View File

@@ -9,8 +9,9 @@ class PinballLoss(nn.Module):
def forward(self, pred, target):
error = target - pred
upper = self.quantiles_tensor * error
lower = (self.quantiles_tensor - 1) * error
quantiles = self.quantiles_tensor.to(error.device)
upper = quantiles * error
lower = (quantiles - 1) * error
losses = torch.max(lower, upper)
loss = torch.mean(torch.mean(losses, dim=0))
return loss
@@ -26,8 +27,10 @@ class NonAutoRegressivePinballLoss(nn.Module):
pred = pred.reshape(-1, 96, len(self.quantiles_tensor))
target_expanded = target.unsqueeze(2)
error = target_expanded - pred
upper = self.quantiles_tensor * error
lower = (self.quantiles_tensor - 1) * error
quantiles = self.quantiles_tensor.to(error.device)
upper = quantiles * error
lower = (quantiles - 1) * error
losses = torch.max(lower, upper)
loss = torch.mean(losses)
return loss