Added LSTM model
This commit is contained in:
@@ -36,9 +36,10 @@ class DataConfig:
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
def __init__(self, data_config: DataConfig, path:str="./"):
|
||||
def __init__(self, data_config: DataConfig, lstm: bool = False, path:str="./"):
|
||||
self.batch_size = 2048
|
||||
self.path = path
|
||||
self.lstm = lstm
|
||||
|
||||
self.train_range = (
|
||||
-np.inf,
|
||||
@@ -204,6 +205,7 @@ class DataProcessor:
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
lstm=self.lstm,
|
||||
)
|
||||
return self.get_dataloader(train_dataset, shuffle=shuffle)
|
||||
|
||||
@@ -234,6 +236,7 @@ class DataProcessor:
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
lstm=self.lstm,
|
||||
)
|
||||
return self.get_dataloader(test_dataset, shuffle=False)
|
||||
|
||||
@@ -274,7 +277,7 @@ class DataProcessor:
|
||||
predict_sequence_length=self.output_size
|
||||
)
|
||||
input, _, _ = next(iter(data_loader))
|
||||
return input.shape[-1]
|
||||
return input.shape
|
||||
|
||||
def get_time_feature_size(self):
|
||||
time_feature_size = 1
|
||||
|
||||
Reference in New Issue
Block a user