Fixing some stuff

This commit is contained in:
Victor Mylle
2023-12-30 15:22:32 +00:00
parent ef8b5f49ac
commit c26ae76951
6 changed files with 107 additions and 33 deletions

View File

@@ -71,7 +71,6 @@ class NrvDataset(Dataset):
self.nrv = torch.tensor(dataframe["nrv"].values).float().reshape(-1)
self.datetime = dataframe["datetime"]
print(dataframe.columns)
self.history_features, self.forecast_features = self.preprocess_data(dataframe)
def skip_samples(self, dataframe):

View File

@@ -253,7 +253,7 @@ class DataProcessor:
return self.get_dataloader(train_dataset, shuffle=shuffle)
def get_test_dataloader(
self, transform: bool = True, predict_sequence_length: int = 96
self, transform: bool = True, predict_sequence_length: int = 96, full_day_skip: bool = False
):
test_df = self.all_features.copy()
@@ -287,19 +287,19 @@ class DataProcessor:
test_dataset = NrvDataset(
test_df,
data_config=self.data_config,
full_day_skip=self.full_day_skip,
full_day_skip=self.full_day_skip or full_day_skip,
predict_sequence_length=predict_sequence_length,
lstm=self.lstm,
)
return self.get_dataloader(test_dataset, shuffle=False)
def get_dataloaders(
self, transform: bool = True, predict_sequence_length: int = 96
self, transform: bool = True, predict_sequence_length: int = 96, full_day_skip: bool = False
):
return self.get_train_dataloader(
transform=transform, predict_sequence_length=predict_sequence_length
), self.get_test_dataloader(
transform=transform, predict_sequence_length=predict_sequence_length
transform=transform, predict_sequence_length=predict_sequence_length, full_day_skip=full_day_skip
)
def inverse_transform(self, input_data):