Sped up sampling 20x

This commit is contained in:
Victor Mylle
2023-11-25 18:09:42 +00:00
parent 5de3f64a1a
commit 300f268286
10 changed files with 498 additions and 238 deletions

View File

@@ -167,7 +167,10 @@ class DataProcessor:
)
def get_train_dataloader(
self, transform: bool = True, predict_sequence_length: int = 96
self,
transform: bool = True,
predict_sequence_length: int = 96,
shuffle: bool = True,
):
train_df = self.all_features.copy()
@@ -194,7 +197,7 @@ class DataProcessor:
full_day_skip=self.full_day_skip,
predict_sequence_length=predict_sequence_length,
)
return self.get_dataloader(train_dataset)
return self.get_dataloader(train_dataset, shuffle=shuffle)
def get_test_dataloader(
self, transform: bool = True, predict_sequence_length: int = 96
@@ -262,5 +265,5 @@ class DataProcessor:
data_loader = self.get_train_dataloader(
predict_sequence_length=self.output_size
)
input, _ = next(iter(data_loader))
input, _, _ = next(iter(data_loader))
return input.shape[-1]