Added non autoregressive quantile results + changing sample plots
This commit is contained in:
@@ -417,7 +417,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
for i in range(10):
|
||||
ax2.plot(predictions_np[i], label=f"Sample {i}")
|
||||
|
||||
ax2.plot(next_day_np, label="Real NRV", linewidth=3)
|
||||
ax2.plot(next_day_np, label="Real NRV", linewidth=4, color="orange")
|
||||
ax2.legend()
|
||||
|
||||
ax2.set_ylim(-1500, 1500)
|
||||
@@ -561,16 +561,19 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
outputs = outputs.reshape(-1, 96, len(self.quantiles))
|
||||
|
||||
outputted_samples = [
|
||||
sample_from_dist(self.quantiles, output.cpu()) for _ in range(100) for output in outputs
|
||||
sample_from_dist(self.quantiles, output.cpu())
|
||||
for _ in range(100)
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
|
||||
outputted_samples = torch.tensor(outputted_samples)
|
||||
inversed_outputs_samples = self.data_processor.inverse_transform(
|
||||
outputted_samples
|
||||
)
|
||||
|
||||
expanded_targets = targets.unsqueeze(1).repeat(1, 100, 1).reshape(-1, 96)
|
||||
expanded_targets = (
|
||||
targets.unsqueeze(1).repeat(1, 100, 1).reshape(-1, 96)
|
||||
)
|
||||
inversed_expanded_targets = self.data_processor.inverse_transform(
|
||||
expanded_targets
|
||||
)
|
||||
@@ -587,7 +590,6 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
expanded_targets = expanded_targets.to(self.device)
|
||||
inversed_expanded_targets = inversed_expanded_targets.to(self.device)
|
||||
|
||||
|
||||
for metric in self.metrics_to_track:
|
||||
if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
|
||||
transformed_metrics[metric.__class__.__name__] += metric(
|
||||
@@ -628,7 +630,9 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
name=metric_name, value=metric_value
|
||||
)
|
||||
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
def debug_plots(
|
||||
self, task, train: bool, data_loader, sample_indices, epoch, final=False
|
||||
):
|
||||
for actual_idx, idx in sample_indices.items():
|
||||
features, target, _ = data_loader.dataset[idx]
|
||||
|
||||
@@ -664,6 +668,24 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
report_interactive=False,
|
||||
)
|
||||
|
||||
if final:
|
||||
# fig to PIL image
|
||||
fig.savefig(f"sample_{actual_idx}_plot.png")
|
||||
task.get_logger().report_image(
|
||||
title="Final Training Plot",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
image_path=f"sample_{actual_idx}_plot.png",
|
||||
)
|
||||
|
||||
fig2.savefig(f"sample_{actual_idx}_samples_plot.png")
|
||||
task.get_logger().report_image(
|
||||
title="Final Training Samples Plot",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
image_path=f"sample_{actual_idx}_samples_plot.png",
|
||||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
def get_plot(
|
||||
|
||||
Reference in New Issue
Block a user