Added LSTM model

This commit is contained in:
Victor Mylle
2023-11-28 22:27:15 +00:00
parent ffa19592f9
commit eba10c8f83
6 changed files with 117 additions and 38 deletions

View File

@@ -36,9 +36,10 @@ class DataConfig:
class DataProcessor:
def __init__(self, data_config: DataConfig, path:str="./"):
def __init__(self, data_config: DataConfig, lstm: bool = False, path:str="./"):
self.batch_size = 2048
self.path = path
self.lstm = lstm
self.train_range = (
-np.inf,
@@ -204,6 +205,7 @@ class DataProcessor:
data_config=self.data_config,
full_day_skip=self.full_day_skip,
predict_sequence_length=predict_sequence_length,
lstm=self.lstm,
)
return self.get_dataloader(train_dataset, shuffle=shuffle)
@@ -234,6 +236,7 @@ class DataProcessor:
data_config=self.data_config,
full_day_skip=self.full_day_skip,
predict_sequence_length=predict_sequence_length,
lstm=self.lstm,
)
return self.get_dataloader(test_dataset, shuffle=False)
@@ -274,7 +277,7 @@ class DataProcessor:
predict_sequence_length=self.output_size
)
input, _, _ = next(iter(data_loader))
return input.shape[-1]
return input.shape
def get_time_feature_size(self):
time_feature_size = 1