Added LSTM model
This commit is contained in:
@@ -12,10 +12,12 @@ class NrvDataset(Dataset):
|
||||
full_day_skip: bool = False,
|
||||
sequence_length=96,
|
||||
predict_sequence_length=96,
|
||||
lstm: bool = False,
|
||||
):
|
||||
self.data_config = data_config
|
||||
self.dataframe = dataframe
|
||||
self.full_day_skip = full_day_skip
|
||||
self.lstm = lstm
|
||||
|
||||
# reset dataframe index
|
||||
self.dataframe.reset_index(drop=True, inplace=True)
|
||||
@@ -107,19 +109,26 @@ class NrvDataset(Dataset):
|
||||
history_features = history_df[self.history_features].values
|
||||
|
||||
# combine the history features to one tensor (first one feature, then the next one, etc.)
|
||||
history_features = torch.tensor(history_features).reshape(-1)
|
||||
history_features = torch.tensor(history_features)
|
||||
|
||||
# get forecast features
|
||||
forecast_features = forecast_df[self.forecast_features].values
|
||||
forecast_features = torch.tensor(forecast_features).view(-1)
|
||||
forecast_features = torch.tensor(forecast_features)
|
||||
|
||||
# add last time feature of the history
|
||||
time_feature = history_df["time_feature"].iloc[-1]
|
||||
|
||||
## all features
|
||||
all_features = torch.cat(
|
||||
[nrv_features, history_features, forecast_features, torch.tensor([time_feature])], dim=0
|
||||
)
|
||||
if not self.lstm:
|
||||
all_features = torch.cat(
|
||||
[nrv_features, history_features.reshape(-1), forecast_features.reshape(-1), torch.tensor([time_feature])], dim=0
|
||||
)
|
||||
else:
|
||||
time_features = torch.tensor(history_df["time_feature"].values).reshape(-1, 1)
|
||||
# combine (96, ) and (96, 2) to (96, 3)
|
||||
all_features = torch.cat(
|
||||
[nrv_features.unsqueeze(1), time_features], dim=1
|
||||
)
|
||||
|
||||
# Target sequence, flattened if necessary
|
||||
nrv_target = forecast_df["nrv"].values
|
||||
@@ -133,7 +142,7 @@ class NrvDataset(Dataset):
|
||||
# all features and target to float
|
||||
all_features = all_features.float()
|
||||
|
||||
# to tensors
|
||||
# to tens&éazzaéaz"ezéors
|
||||
nrv_target = torch.tensor(nrv_target).float()
|
||||
return all_features, nrv_target, idx
|
||||
|
||||
|
||||
@@ -36,9 +36,10 @@ class DataConfig:
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
def __init__(self, data_config: DataConfig, path:str="./"):
|
||||
def __init__(self, data_config: DataConfig, lstm: bool = False, path:str="./"):
|
||||
self.batch_size = 2048
|
||||
self.path = path
|
||||
self.lstm = lstm
|
||||
|
||||
self.train_range = (
|
||||
-np.inf,
|
||||
@@ -204,6 +205,7 @@ class DataProcessor:
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
lstm=self.lstm,
|
||||
)
|
||||
return self.get_dataloader(train_dataset, shuffle=shuffle)
|
||||
|
||||
@@ -234,6 +236,7 @@ class DataProcessor:
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
lstm=self.lstm,
|
||||
)
|
||||
return self.get_dataloader(test_dataset, shuffle=False)
|
||||
|
||||
@@ -274,7 +277,7 @@ class DataProcessor:
|
||||
predict_sequence_length=self.output_size
|
||||
)
|
||||
input, _, _ = next(iter(data_loader))
|
||||
return input.shape[-1]
|
||||
return input.shape
|
||||
|
||||
def get_time_feature_size(self):
|
||||
time_feature_size = 1
|
||||
|
||||
Reference in New Issue
Block a user