Fixing some stuff

This commit is contained in:
Victor Mylle
2023-12-30 15:22:32 +00:00
parent ef8b5f49ac
commit c26ae76951
6 changed files with 107 additions and 33 deletions

View File

@@ -253,7 +253,7 @@ class DataProcessor:
return self.get_dataloader(train_dataset, shuffle=shuffle)
def get_test_dataloader(
self, transform: bool = True, predict_sequence_length: int = 96
self, transform: bool = True, predict_sequence_length: int = 96, full_day_skip: bool = False
):
test_df = self.all_features.copy()
@@ -287,19 +287,19 @@ class DataProcessor:
test_dataset = NrvDataset(
test_df,
data_config=self.data_config,
full_day_skip=self.full_day_skip,
full_day_skip=self.full_day_skip or full_day_skip,
predict_sequence_length=predict_sequence_length,
lstm=self.lstm,
)
return self.get_dataloader(test_dataset, shuffle=False)
def get_dataloaders(
self, transform: bool = True, predict_sequence_length: int = 96
self, transform: bool = True, predict_sequence_length: int = 96, full_day_skip: bool = False
):
return self.get_train_dataloader(
transform=transform, predict_sequence_length=predict_sequence_length
), self.get_test_dataloader(
transform=transform, predict_sequence_length=predict_sequence_length
transform=transform, predict_sequence_length=predict_sequence_length, full_day_skip=full_day_skip
)
def inverse_transform(self, input_data):