Fixed crps + more inputs
This commit is contained in:
@@ -94,7 +94,6 @@ class AutoRegressiveTrainer(Trainer):
|
||||
|
||||
target_full.append(target)
|
||||
with torch.no_grad():
|
||||
print(prev_features.shape)
|
||||
prediction = self.model(prev_features.unsqueeze(0))
|
||||
predictions_full.append(prediction.squeeze(-1))
|
||||
|
||||
@@ -107,8 +106,6 @@ class AutoRegressiveTrainer(Trainer):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
print(new_features.shape)
|
||||
|
||||
# get the other needed features
|
||||
other_features, new_target = data_loader.dataset.random_day_autoregressive(
|
||||
idx + i + 1
|
||||
|
||||
@@ -80,7 +80,7 @@ class ProbabilisticBaselineTrainer(Trainer):
|
||||
raise
|
||||
|
||||
def log_final_metrics(self, task, dataloader, quantile_values, train: bool = True):
|
||||
metric = CRPSLoss()
|
||||
metric = CRPSLoss(quantiles=self.quantiles)
|
||||
|
||||
crps_values = []
|
||||
crps_inversed_values = []
|
||||
|
||||
@@ -270,7 +270,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
)
|
||||
|
||||
def plot_quantile_percentages(
|
||||
self, task, data_loader, train: bool = True, iteration: int = None
|
||||
self, task, data_loader, train: bool = True, iteration: int = None, full_day: bool = False
|
||||
):
|
||||
quantiles = self.quantiles
|
||||
total = 0
|
||||
@@ -278,16 +278,34 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
for inputs, targets, _ in data_loader:
|
||||
inputs = inputs.to(self.device)
|
||||
output = self.model(inputs).cpu().numpy()
|
||||
targets = targets.squeeze(-1).cpu().numpy()
|
||||
total_samples = len(data_loader.dataset) - 96
|
||||
|
||||
for inputs, targets, idx_batch in data_loader:
|
||||
idx_batch = [idx for idx in idx_batch if idx < total_samples]
|
||||
|
||||
if full_day:
|
||||
_, outputs, samples, targets = self.auto_regressive(
|
||||
data_loader.dataset, idx_batch=idx_batch
|
||||
)
|
||||
# outputs: (batch, sequence_length, num_quantiles)
|
||||
# targets: (batch, sequence_length, 1)
|
||||
|
||||
# reshape to (batch_size * sequence_length, num_quantiles)
|
||||
outputs = outputs.reshape(-1, len(quantiles))
|
||||
targets = targets.reshape(-1)
|
||||
|
||||
# to cpu
|
||||
outputs = outputs.cpu().numpy()
|
||||
targets = targets.cpu().numpy()
|
||||
|
||||
else:
|
||||
inputs = inputs.to(self.device)
|
||||
outputs = self.model(inputs).cpu().numpy() # (batch_size, num_quantiles)
|
||||
targets = targets.squeeze(-1).cpu().numpy() # (batch_size, 1)
|
||||
|
||||
# output shape: (batch_size, num_quantiles)
|
||||
# target shape: (batch_size, 1)
|
||||
for i, q in enumerate(quantiles):
|
||||
quantile_counter[q] += np.sum(
|
||||
targets < output[:, i]
|
||||
targets < outputs[:, i]
|
||||
)
|
||||
|
||||
total += len(targets)
|
||||
@@ -322,18 +340,19 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
) # Format the number as a percentage
|
||||
|
||||
series_name = "Training Set" if train else "Test Set"
|
||||
full_day_str = "Full Day" if full_day else "Single Step"
|
||||
|
||||
# Adding labels and title
|
||||
ax.set_xlabel("Quantile")
|
||||
ax.set_ylabel("Fraction of data under quantile forecast")
|
||||
ax.set_title(f"Quantile Performance Comparison ({series_name})")
|
||||
ax.set_title(f"{series_name} {full_day_str} Quantile Performance Comparison")
|
||||
ax.set_xticks(index + bar_width / 2)
|
||||
ax.set_xticklabels(quantiles)
|
||||
ax.legend()
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Quantile Performance Comparison",
|
||||
series=series_name,
|
||||
series=f"{series_name} {full_day_str}",
|
||||
report_image=True,
|
||||
figure=plt,
|
||||
iteration=iteration,
|
||||
|
||||
@@ -166,11 +166,17 @@ class Trainer:
|
||||
|
||||
if hasattr(self, "plot_quantile_percentages"):
|
||||
self.plot_quantile_percentages(
|
||||
task, train_loader, True, epoch
|
||||
task, train_loader, True, epoch, False
|
||||
)
|
||||
# self.plot_quantile_percentages(
|
||||
# task, train_loader, True, epoch, True
|
||||
# )
|
||||
self.plot_quantile_percentages(
|
||||
task, test_loader, False, epoch
|
||||
task, test_loader, False, epoch, False
|
||||
)
|
||||
# self.plot_quantile_percentages(
|
||||
# task, test_loader, False, epoch, True
|
||||
# )
|
||||
|
||||
if task:
|
||||
self.finish_training(task=task)
|
||||
|
||||
Reference in New Issue
Block a user