Fixing some stuff
This commit is contained in:
@@ -71,7 +71,6 @@ class NrvDataset(Dataset):
|
||||
self.nrv = torch.tensor(dataframe["nrv"].values).float().reshape(-1)
|
||||
self.datetime = dataframe["datetime"]
|
||||
|
||||
print(dataframe.columns)
|
||||
self.history_features, self.forecast_features = self.preprocess_data(dataframe)
|
||||
|
||||
def skip_samples(self, dataframe):
|
||||
|
||||
@@ -253,7 +253,7 @@ class DataProcessor:
|
||||
return self.get_dataloader(train_dataset, shuffle=shuffle)
|
||||
|
||||
def get_test_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
self, transform: bool = True, predict_sequence_length: int = 96, full_day_skip: bool = False
|
||||
):
|
||||
test_df = self.all_features.copy()
|
||||
|
||||
@@ -287,19 +287,19 @@ class DataProcessor:
|
||||
test_dataset = NrvDataset(
|
||||
test_df,
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
full_day_skip=self.full_day_skip or full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
lstm=self.lstm,
|
||||
)
|
||||
return self.get_dataloader(test_dataset, shuffle=False)
|
||||
|
||||
def get_dataloaders(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
self, transform: bool = True, predict_sequence_length: int = 96, full_day_skip: bool = False
|
||||
):
|
||||
return self.get_train_dataloader(
|
||||
transform=transform, predict_sequence_length=predict_sequence_length
|
||||
), self.get_test_dataloader(
|
||||
transform=transform, predict_sequence_length=predict_sequence_length
|
||||
transform=transform, predict_sequence_length=predict_sequence_length, full_day_skip=full_day_skip
|
||||
)
|
||||
|
||||
def inverse_transform(self, input_data):
|
||||
|
||||
Reference in New Issue
Block a user