Added non autoregressive quantile results + changing sample plots

This commit is contained in:
2024-04-19 12:35:27 +02:00
parent 98a7244995
commit 4e713ef564
15 changed files with 107 additions and 48 deletions

View File

@@ -417,7 +417,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
for i in range(10):
ax2.plot(predictions_np[i], label=f"Sample {i}")
ax2.plot(next_day_np, label="Real NRV", linewidth=3)
ax2.plot(next_day_np, label="Real NRV", linewidth=4, color="orange")
ax2.legend()
ax2.set_ylim(-1500, 1500)
@@ -561,16 +561,19 @@ class NonAutoRegressiveQuantileRegression(Trainer):
outputs = outputs.reshape(-1, 96, len(self.quantiles))
outputted_samples = [
sample_from_dist(self.quantiles, output.cpu()) for _ in range(100) for output in outputs
sample_from_dist(self.quantiles, output.cpu())
for _ in range(100)
for output in outputs
]
outputted_samples = torch.tensor(outputted_samples)
inversed_outputs_samples = self.data_processor.inverse_transform(
outputted_samples
)
expanded_targets = targets.unsqueeze(1).repeat(1, 100, 1).reshape(-1, 96)
expanded_targets = (
targets.unsqueeze(1).repeat(1, 100, 1).reshape(-1, 96)
)
inversed_expanded_targets = self.data_processor.inverse_transform(
expanded_targets
)
@@ -587,7 +590,6 @@ class NonAutoRegressiveQuantileRegression(Trainer):
expanded_targets = expanded_targets.to(self.device)
inversed_expanded_targets = inversed_expanded_targets.to(self.device)
for metric in self.metrics_to_track:
if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
transformed_metrics[metric.__class__.__name__] += metric(
@@ -628,7 +630,9 @@ class NonAutoRegressiveQuantileRegression(Trainer):
name=metric_name, value=metric_value
)
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
def debug_plots(
self, task, train: bool, data_loader, sample_indices, epoch, final=False
):
for actual_idx, idx in sample_indices.items():
features, target, _ = data_loader.dataset[idx]
@@ -664,6 +668,24 @@ class NonAutoRegressiveQuantileRegression(Trainer):
report_interactive=False,
)
if final:
# fig to PIL image
fig.savefig(f"sample_{actual_idx}_plot.png")
task.get_logger().report_image(
title="Final Training Plot",
series=f"Sample {actual_idx}",
iteration=epoch,
image_path=f"sample_{actual_idx}_plot.png",
)
fig2.savefig(f"sample_{actual_idx}_samples_plot.png")
task.get_logger().report_image(
title="Final Training Samples Plot",
series=f"Sample {actual_idx} samples",
iteration=epoch,
image_path=f"sample_{actual_idx}_samples_plot.png",
)
plt.close()
def get_plot(