Fixed small issues in the summary; it now covers the model architectures implemented so far

This commit is contained in:
Victor Mylle
2023-11-30 21:53:35 +00:00
parent eba10c8f83
commit 120b6aa5bd
23 changed files with 402 additions and 185 deletions

View File

@@ -15,6 +15,7 @@ class AutoRegressiveTrainer(Trainer):
def __init__(
self,
model: torch.nn.Module,
input_dim: tuple,
optimizer: torch.optim.Optimizer,
criterion: torch.nn.Module,
data_processor: DataProcessor,
@@ -23,6 +24,7 @@ class AutoRegressiveTrainer(Trainer):
):
super().__init__(
model=model,
input_dim=input_dim,
optimizer=optimizer,
criterion=criterion,
data_processor=data_processor,

View File

@@ -48,7 +48,7 @@ class ProbabilisticBaselineTrainer(Trainer):
predict_sequence_length=96
)
for inputs, _ in train_loader:
for inputs, _, _ in train_loader:
for i in range(96):
time_steps[i].extend(inputs[:, i].numpy())
@@ -80,7 +80,7 @@ class ProbabilisticBaselineTrainer(Trainer):
raise
def log_final_metrics(self, task, dataloader, quantile_values, train: bool = True):
metric = CRPSLoss(self.quantiles)
metric = CRPSLoss()
crps_values = []
crps_inversed_values = []
@@ -147,6 +147,9 @@ class ProbabilisticBaselineTrainer(Trainer):
def plot_quantiles(self, quantile_values):
fig = go.Figure()
# inverse transform quantile_values
quantile_values = self.data_processor.inverse_transform(quantile_values)
for i, q in enumerate(self.quantiles):
values_for_quantile = quantile_values[:, i]
fig.add_trace(
@@ -159,7 +162,8 @@ class ProbabilisticBaselineTrainer(Trainer):
)
fig.update_layout(title="Quantile Values")
fig.update_yaxes(range=[-1, 1])
fig.update_layout(height=600)
return fig

View File

@@ -60,6 +60,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
def __init__(
self,
model: torch.nn.Module,
input_dim: tuple,
optimizer: torch.optim.Optimizer,
data_processor: DataProcessor,
quantiles: list,
@@ -72,6 +73,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
criterion = PinballLoss(quantiles=quantiles)
super().__init__(
model=model,
input_dim=input_dim,
optimizer=optimizer,
criterion=criterion,
data_processor=data_processor,
@@ -192,7 +194,10 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
prev_features = prev_features.to(self.device)
targets = targets.to(self.device)
initial_sequence = prev_features[:, :96]
if len(list(prev_features.shape)) == 2:
initial_sequence = prev_features[:, :96]
else:
initial_sequence = prev_features[:, :, 0]
target_full = targets[:, 0].unsqueeze(1) # (batch_size, 1)
with torch.no_grad():
@@ -206,22 +211,37 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
predictions_full = new_predictions_full.unsqueeze(1)
for i in range(sequence_length - 1):
new_features = torch.cat(
(prev_features[:, 1:96], samples), dim=1
) # (batch_size, 96)
if len(list(prev_features.shape)) == 2:
new_features = torch.cat(
(prev_features[:, 1:96], samples), dim=1
) # (batch_size, 96)
new_features = new_features.float()
new_features = new_features.float()
other_features, new_targets = dataset.get_batch_autoregressive(
np.array(idx_batch) + i + 1
) # (batch_size, new_features)
other_features, new_targets = dataset.get_batch_autoregressive(
np.array(idx_batch) + i + 1
) # (batch_size, new_features)
if other_features is not None:
prev_features = torch.cat(
(new_features.to(self.device), other_features.to(self.device)), dim=1
) # (batch_size, 96 + new_features)
else:
prev_features = new_features
if other_features is not None:
prev_features = torch.cat(
(new_features.to(self.device), other_features.to(self.device)), dim=1
) # (batch_size, 96 + new_features)
else:
prev_features = new_features
other_features, new_targets = dataset.get_batch_autoregressive(
np.array(idx_batch) + i + 1
) # (batch_size, 1, new_features)
# change the other_features nrv based on the samples
other_features[:, 0, 0] = samples.squeeze(-1)
# make sure on same device
other_features = other_features.to(self.device)
prev_features = prev_features.to(self.device)
prev_features = torch.cat(
(prev_features[:, 1:, :], other_features), dim=1
) # (batch_size, 96, new_features)
target_full = torch.cat(
(target_full, new_targets.to(self.device)), dim=1

View File

@@ -6,18 +6,20 @@ import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
from clearml.config import running_remotely
from torchinfo import summary
class Trainer:
def __init__(
self,
model: torch.nn.Module,
input_dim: tuple,
optimizer: torch.optim.Optimizer,
criterion: torch.nn.Module,
data_processor: DataProcessor,
device: torch.device,
debug: bool = True,
):
self.input_dim = input_dim
self.model = model
self.optimizer = optimizer
self.criterion = criterion
@@ -70,6 +72,8 @@ class Trainer:
task.add_tags(self.optimizer.__class__.__name__)
task.add_tags(self.__class__.__name__)
task.set_configuration_object("model", str(summary(self.model, self.input_dim)))
self.optimizer.name = self.optimizer.__class__.__name__
self.criterion.name = self.criterion.__class__.__name__