Added LSTM model

This commit is contained in:
Victor Mylle
2023-11-28 22:27:15 +00:00
parent ffa19592f9
commit eba10c8f83
6 changed files with 117 additions and 38 deletions

View File

@@ -12,10 +12,12 @@ class NrvDataset(Dataset):
full_day_skip: bool = False,
sequence_length=96,
predict_sequence_length=96,
lstm: bool = False,
):
self.data_config = data_config
self.dataframe = dataframe
self.full_day_skip = full_day_skip
self.lstm = lstm
# reset dataframe index
self.dataframe.reset_index(drop=True, inplace=True)
@@ -107,19 +109,26 @@ class NrvDataset(Dataset):
history_features = history_df[self.history_features].values
# combine the history features to one tensor (first one feature, then the next one, etc.)
history_features = torch.tensor(history_features).reshape(-1)
history_features = torch.tensor(history_features)
# get forecast features
forecast_features = forecast_df[self.forecast_features].values
forecast_features = torch.tensor(forecast_features).view(-1)
forecast_features = torch.tensor(forecast_features)
# add last time feature of the history
time_feature = history_df["time_feature"].iloc[-1]
## all features
all_features = torch.cat(
[nrv_features, history_features, forecast_features, torch.tensor([time_feature])], dim=0
)
if not self.lstm:
all_features = torch.cat(
[nrv_features, history_features.reshape(-1), forecast_features.reshape(-1), torch.tensor([time_feature])], dim=0
)
else:
time_features = torch.tensor(history_df["time_feature"].values).reshape(-1, 1)
# combine (96, ) and (96, 2) to (96, 3)
all_features = torch.cat(
[nrv_features.unsqueeze(1), time_features], dim=1
)
# Target sequence, flattened if necessary
nrv_target = forecast_df["nrv"].values
@@ -133,7 +142,7 @@ class NrvDataset(Dataset):
# all features and target to float
all_features = all_features.float()
# to tensors
# to tensors
nrv_target = torch.tensor(nrv_target).float()
return all_features, nrv_target, idx

View File

@@ -36,9 +36,10 @@ class DataConfig:
class DataProcessor:
def __init__(self, data_config: DataConfig, path:str="./"):
def __init__(self, data_config: DataConfig, lstm: bool = False, path:str="./"):
self.batch_size = 2048
self.path = path
self.lstm = lstm
self.train_range = (
-np.inf,
@@ -204,6 +205,7 @@ class DataProcessor:
data_config=self.data_config,
full_day_skip=self.full_day_skip,
predict_sequence_length=predict_sequence_length,
lstm=self.lstm,
)
return self.get_dataloader(train_dataset, shuffle=shuffle)
@@ -234,6 +236,7 @@ class DataProcessor:
data_config=self.data_config,
full_day_skip=self.full_day_skip,
predict_sequence_length=predict_sequence_length,
lstm=self.lstm,
)
return self.get_dataloader(test_dataset, shuffle=False)
@@ -274,7 +277,7 @@ class DataProcessor:
predict_sequence_length=self.output_size
)
input, _, _ = next(iter(data_loader))
return input.shape[-1]
return input.shape
def get_time_feature_size(self):
time_feature_size = 1